diff --git a/test/dynamo/cpython/3_13/test_with.diff b/test/dynamo/cpython/3_13/test_with.diff index 696fefb91edf..29d0550c419f 100644 --- a/test/dynamo/cpython/3_13/test_with.diff +++ b/test/dynamo/cpython/3_13/test_with.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_with.py b/test/dynamo/cpython/3_13/test_with.py -index 8e9ed8500c7..e1ebaa68b83 100644 +index 8e9ed8500c7..66c18ad886a 100644 --- a/test/dynamo/cpython/3_13/test_with.py +++ b/test/dynamo/cpython/3_13/test_with.py @@ -1,3 +1,23 @@ @@ -26,7 +26,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 """Unit tests for the with statement specified in PEP 343.""" -@@ -104,7 +124,7 @@ class MockNested(Nested): +@@ -104,16 +124,17 @@ class MockNested(Nested): return Nested.__exit__(self, *exc_info) @@ -35,7 +35,82 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def testNameError(self): def fooNotDeclared(): with foo: pass -@@ -194,6 +214,7 @@ class ContextmanagerAssertionMixin(object): + self.assertRaises(NameError, fooNotDeclared) + + def testEnterAttributeError1(self): +- class LacksEnter(object): +- def __exit__(self, type, value, traceback): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class LacksEnter(object): ++ def __exit__(self, type, value, traceback): ++ pass + + def fooLacksEnter(): + foo = LacksEnter() +@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase): + self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter) + + def testEnterAttributeError2(self): +- class LacksEnterAndExit(object): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class LacksEnterAndExit(object): ++ pass + + def fooLacksEnterAndExit(): + foo = LacksEnterAndExit() +@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase): + self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit) + + def testExitAttributeError(self): +- class LacksExit(object): +- def __enter__(self): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class LacksExit(object): ++ def __enter__(self): ++ pass + + def fooLacksExit(): + foo = LacksExit() +@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase): + ' pass') + + def testEnterThrows(self): +- class EnterThrows(object): +- def __enter__(self): +- raise RuntimeError("Enter threw") +- def __exit__(self, *args): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class EnterThrows(object): ++ def __enter__(self): ++ raise RuntimeError("Enter threw") ++ def __exit__(self, *args): ++ pass + + def shouldThrow(): + ct = EnterThrows() +@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase): + self.assertEqual(self.foo, None) + + def testExitThrows(self): +- class ExitThrows(object): +- def __enter__(self): +- return +- def __exit__(self, *args): +- raise RuntimeError(42) ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class ExitThrows(object): ++ def __enter__(self): ++ return ++ def __exit__(self, *args): ++ raise RuntimeError(42) + def shouldThrow(): + with ExitThrows(): + pass +@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object): def setUp(self): self.TEST_EXCEPTION = RuntimeError("test exception") @@ -43,7 +118,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def assertInWithManagerInvariants(self, mock_manager): self.assertTrue(mock_manager.enter_called) -@@ -237,7 +258,7 @@ class ContextmanagerAssertionMixin(object): +@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object): self.assertTrue(mock_generator.stopped) @@ -52,7 +127,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def testInlineGeneratorSyntax(self): with mock_contextmanager_generator(): pass -@@ -289,7 +310,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): +@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): self.assertAfterWithGeneratorInvariantsNoError(foo) @@ -61,7 +136,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 ContextmanagerAssertionMixin): def testSingleArgInlineGeneratorSyntax(self): with Nested(mock_contextmanager_generator()): -@@ -355,7 +376,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase, +@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase, self.assertAfterWithManagerInvariantsNoError(mock_nested) @@ -70,7 +145,71 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def testSingleResource(self): cm = mock_contextmanager_generator() def shouldThrow(): -@@ -550,7 +571,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): +@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): + + def testRaisedStopIteration2(self): + # From bug 1462485 +- class cm(object): +- def __enter__(self): +- pass +- def __exit__(self, type, value, traceback): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class cm(object): ++ def __enter__(self): ++ pass ++ def __exit__(self, type, value, traceback): ++ pass + + def shouldThrow(): + with cm(): +@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): + + def testRaisedGeneratorExit2(self): + # From bug 1462485 +- class cm (object): +- def __enter__(self): +- pass +- def __exit__(self, type, value, traceback): +- pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class cm (object): ++ def __enter__(self): ++ pass ++ def __exit__(self, type, value, traceback): ++ pass + + def shouldThrow(): + with cm(): +@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): + # issue4589: __exit__ return code may raise an exception + # when looking at its truth value. + +- class cm(object): +- def __init__(self, bool_conversion): +- class Bool: +- def __bool__(self): +- return bool_conversion() +- self.exit_result = Bool() +- def __enter__(self): +- return 3 +- def __exit__(self, a, b, c): +- return self.exit_result ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class cm(object): ++ def __init__(self, bool_conversion): ++ class Bool: ++ def __bool__(self): ++ return bool_conversion() ++ self.exit_result = Bool() ++ def __enter__(self): ++ return 3 ++ def __exit__(self, a, b, c): ++ return self.exit_result + + def trueAsBool(): + with cm(lambda: True): +@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): self.assertRaises(ZeroDivisionError, failAsBool) @@ -79,7 +218,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def testWithBreak(self): counter = 0 -@@ -607,7 +628,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase): +@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase): self.fail("Didn't raise RuntimeError") @@ -88,7 +227,39 @@ index 8e9ed8500c7..e1ebaa68b83 100644 def testSingleComplexTarget(self): targets = {1: [0, 1, 2]} -@@ -651,7 +672,7 @@ class AssignmentTargetTestCase(unittest.TestCase): +@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase): + keys = list(targets.keys()) + keys.sort() + self.assertEqual(keys, [1, 2]) +- class C: pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class C: pass + blah = C() + with mock_contextmanager_generator() as blah.foo: + self.assertEqual(hasattr(blah, "foo"), True) + + def testMultipleComplexTargets(self): +- class C: +- def __enter__(self): return 1, 2, 3 +- def __exit__(self, t, v, tb): pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class C: ++ def __enter__(self): return 1, 2, 3 ++ def __exit__(self, t, v, tb): pass + targets = {1: [0, 1, 2]} + with C() as (targets[1][0], targets[1][1], targets[1][2]): + self.assertEqual(targets, {1: [1, 2, 3]}) +@@ -637,7 +668,8 @@ class AssignmentTargetTestCase(unittest.TestCase): + self.assertEqual(targets, {1: [3, 2, 1]}) + with C() as (targets[1], targets[2], targets[3]): + self.assertEqual(targets, {1: 1, 2: 2, 3: 3}) +- class B: pass ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class B: pass + blah = B() + with C() as (blah.one, blah.two, blah.three): + self.assertEqual(blah.one, 1) +@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase): self.assertEqual(c, 4) @@ -96,8 +267,31 @@ index 8e9ed8500c7..e1ebaa68b83 100644 +class ExitSwallowsExceptionTestCase(__TestCase): def testExitTrueSwallowsException(self): - class AfricanSwallow: -@@ -676,7 +697,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase): +- class AfricanSwallow: +- def __enter__(self): pass +- def __exit__(self, t, v, tb): return True ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class AfricanSwallow: ++ def __enter__(self): pass ++ def __exit__(self, t, v, tb): return True + try: + with AfricanSwallow(): + 1/0 +@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase): + self.fail("ZeroDivisionError should have been swallowed") + + def testExitFalseDoesntSwallowException(self): +- class EuropeanSwallow: +- def __enter__(self): pass +- def __exit__(self, t, v, tb): return False ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ class EuropeanSwallow: ++ def __enter__(self): pass ++ def __exit__(self, t, v, tb): return False + try: + with EuropeanSwallow(): + 1/0 +@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase): self.fail("ZeroDivisionError should have been raised") @@ -106,7 +300,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644 class Dummy(object): def __init__(self, value=None, gobble=False): -@@ -796,4 +817,4 @@ class NestedWith(unittest.TestCase): +@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase): if __name__ == '__main__': diff --git a/test/dynamo/cpython/3_13/test_with.py b/test/dynamo/cpython/3_13/test_with.py index e1ebaa68b839..66c18ad886aa 100644 --- a/test/dynamo/cpython/3_13/test_with.py +++ b/test/dynamo/cpython/3_13/test_with.py @@ -131,9 +131,10 @@ def fooNotDeclared(): self.assertRaises(NameError, fooNotDeclared) def testEnterAttributeError1(self): - class LacksEnter(object): - def __exit__(self, type, value, traceback): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class LacksEnter(object): + def __exit__(self, type, value, traceback): + pass def fooLacksEnter(): foo = LacksEnter() @@ -141,8 +142,9 @@ def fooLacksEnter(): self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter) def testEnterAttributeError2(self): - class LacksEnterAndExit(object): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class LacksEnterAndExit(object): + pass def fooLacksEnterAndExit(): foo = LacksEnterAndExit() @@ -150,9 +152,10 @@ def fooLacksEnterAndExit(): self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit) def testExitAttributeError(self): - class LacksExit(object): - def __enter__(self): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class LacksExit(object): + def __enter__(self): + pass def fooLacksExit(): foo = LacksExit() @@ -182,11 +185,12 @@ def testAssignmentToTupleContainingNoneError(self): ' pass') def testEnterThrows(self): - class EnterThrows(object): - def __enter__(self): - raise RuntimeError("Enter threw") - def __exit__(self, *args): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class EnterThrows(object): + def __enter__(self): + raise RuntimeError("Enter threw") + def __exit__(self, *args): + pass def shouldThrow(): ct = EnterThrows() @@ -200,11 +204,12 @@ def shouldThrow(): self.assertEqual(self.foo, None) def testExitThrows(self): - class ExitThrows(object): - def __enter__(self): - return - def __exit__(self, *args): - raise RuntimeError(42) + with torch._dynamo.set_fullgraph(fullgraph=False): + class ExitThrows(object): + def __enter__(self): + return + def __exit__(self, *args): + raise RuntimeError(42) def shouldThrow(): with ExitThrows(): pass @@ -487,11 +492,12 @@ def shouldThrow(): def testRaisedStopIteration2(self): # From bug 1462485 - class cm(object): - def __enter__(self): - pass - def __exit__(self, type, value, traceback): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class cm(object): + def __enter__(self): + pass + def __exit__(self, type, value, traceback): + pass def shouldThrow(): with cm(): @@ -528,11 +534,12 @@ def shouldThrow(): def testRaisedGeneratorExit2(self): # From bug 1462485 - class cm (object): - def __enter__(self): - pass - def __exit__(self, type, value, traceback): - pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class cm (object): + def __enter__(self): + pass + def __exit__(self, type, value, traceback): + pass def shouldThrow(): with cm(): @@ -544,16 +551,17 @@ def testErrorsInBool(self): # issue4589: __exit__ return code may raise an exception # when looking at its truth value. - class cm(object): - def __init__(self, bool_conversion): - class Bool: - def __bool__(self): - return bool_conversion() - self.exit_result = Bool() - def __enter__(self): - return 3 - def __exit__(self, a, b, c): - return self.exit_result + with torch._dynamo.set_fullgraph(fullgraph=False): + class cm(object): + def __init__(self, bool_conversion): + class Bool: + def __bool__(self): + return bool_conversion() + self.exit_result = Bool() + def __enter__(self): + return 3 + def __exit__(self, a, b, c): + return self.exit_result def trueAsBool(): with cm(lambda: True): @@ -642,15 +650,17 @@ def testSingleComplexTarget(self): keys = list(targets.keys()) keys.sort() self.assertEqual(keys, [1, 2]) - class C: pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class C: pass blah = C() with mock_contextmanager_generator() as blah.foo: self.assertEqual(hasattr(blah, "foo"), True) def testMultipleComplexTargets(self): - class C: - def __enter__(self): return 1, 2, 3 - def __exit__(self, t, v, tb): pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class C: + def __enter__(self): return 1, 2, 3 + def __exit__(self, t, v, tb): pass targets = {1: [0, 1, 2]} with C() as (targets[1][0], targets[1][1], targets[1][2]): self.assertEqual(targets, {1: [1, 2, 3]}) @@ -658,7 +668,8 @@ def __exit__(self, t, v, tb): pass self.assertEqual(targets, {1: [3, 2, 1]}) with C() as (targets[1], targets[2], targets[3]): self.assertEqual(targets, {1: 1, 2: 2, 3: 3}) - class B: pass + with torch._dynamo.set_fullgraph(fullgraph=False): + class B: pass blah = B() with C() as (blah.one, blah.two, blah.three): self.assertEqual(blah.one, 1) @@ -675,9 +686,10 @@ def testWithExtendedTargets(self): class ExitSwallowsExceptionTestCase(__TestCase): def testExitTrueSwallowsException(self): - class AfricanSwallow: - def __enter__(self): pass - def __exit__(self, t, v, tb): return True + with torch._dynamo.set_fullgraph(fullgraph=False): + class AfricanSwallow: + def __enter__(self): pass + def __exit__(self, t, v, tb): return True try: with AfricanSwallow(): 1/0 @@ -685,9 +697,10 @@ def __exit__(self, t, v, tb): return True self.fail("ZeroDivisionError should have been swallowed") def testExitFalseDoesntSwallowException(self): - class EuropeanSwallow: - def __enter__(self): pass - def __exit__(self, t, v, tb): return False + with torch._dynamo.set_fullgraph(fullgraph=False): + class EuropeanSwallow: + def __enter__(self): pass + def __exit__(self, t, v, tb): return False try: with EuropeanSwallow(): 1/0 diff --git a/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testMultipleComplexTargets b/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testMultipleComplexTargets deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testSingleComplexTarget b/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testSingleComplexTarget deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testWithExtendedTargets b/test/dynamo_expected_failures/CPython313-test_with-AssignmentTargetTestCase.testWithExtendedTargets deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testErrorsInBool b/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testErrorsInBool deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testRaisedGeneratorExit2 b/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testRaisedGeneratorExit2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testRaisedStopIteration2 b/test/dynamo_expected_failures/CPython313-test_with-ExceptionalTestCase.testRaisedStopIteration2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-ExitSwallowsExceptionTestCase.testExitFalseDoesntSwallowException b/test/dynamo_expected_failures/CPython313-test_with-ExitSwallowsExceptionTestCase.testExitFalseDoesntSwallowException deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-ExitSwallowsExceptionTestCase.testExitTrueSwallowsException b/test/dynamo_expected_failures/CPython313-test_with-ExitSwallowsExceptionTestCase.testExitTrueSwallowsException deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterAttributeError1 b/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterAttributeError1 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterAttributeError2 b/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterAttributeError2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterThrows b/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testEnterThrows deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testExitAttributeError b/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testExitAttributeError deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testExitThrows b/test/dynamo_expected_failures/CPython313-test_with-FailureTestCase.testExitThrows deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 09c9f4b7b727..5c1a569224c7 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -940,7 +940,8 @@ def __init__(self, target_values=None, **kwargs) -> None: super().__init__(target_values=target_values, **kwargs) def enter(self, tx): - return variables.ConstantVariable.create(None) + none = variables.ConstantVariable.create(None) + return self.target_values if self.target_values else none def exit(self, tx: "InstructionTranslator", *args): return variables.ConstantVariable.create(None) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7cb21ab37280..683d82f69405 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -470,7 +470,7 @@ def call_function( # import here to avoid circular dependency from .ctx_manager import NullContextVariable - return NullContextVariable() + return NullContextVariable(*args, **kwargs) elif self.value is collections.OrderedDict: return tx.inline_user_function_return( VariableTracker.build(tx, polyfills.construct_dict),