Skip to content

Update nullcontext to return input args #158776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/guilhermeleobas/210/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 206 additions & 12 deletions test/dynamo/cpython/3_13/test_with.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_with.py b/test/dynamo/cpython/3_13/test_with.py
index 8e9ed8500c7..e1ebaa68b83 100644
index 8e9ed8500c7..66c18ad886a 100644
--- a/test/dynamo/cpython/3_13/test_with.py
+++ b/test/dynamo/cpython/3_13/test_with.py
@@ -1,3 +1,23 @@
Expand All @@ -26,7 +26,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644
"""Unit tests for the with statement specified in PEP 343."""


@@ -104,7 +124,7 @@ class MockNested(Nested):
@@ -104,16 +124,17 @@ class MockNested(Nested):
return Nested.__exit__(self, *exc_info)


Expand All @@ -35,15 +35,90 @@ index 8e9ed8500c7..e1ebaa68b83 100644
def testNameError(self):
def fooNotDeclared():
with foo: pass
@@ -194,6 +214,7 @@ class ContextmanagerAssertionMixin(object):
self.assertRaises(NameError, fooNotDeclared)

def testEnterAttributeError1(self):
- class LacksEnter(object):
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class LacksEnter(object):
+ def __exit__(self, type, value, traceback):
+ pass

def fooLacksEnter():
foo = LacksEnter()
@@ -121,8 +142,9 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)

def testEnterAttributeError2(self):
- class LacksEnterAndExit(object):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class LacksEnterAndExit(object):
+ pass

def fooLacksEnterAndExit():
foo = LacksEnterAndExit()
@@ -130,9 +152,10 @@ class FailureTestCase(unittest.TestCase):
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)

def testExitAttributeError(self):
- class LacksExit(object):
- def __enter__(self):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class LacksExit(object):
+ def __enter__(self):
+ pass

def fooLacksExit():
foo = LacksExit()
@@ -162,11 +185,12 @@ class FailureTestCase(unittest.TestCase):
' pass')

def testEnterThrows(self):
- class EnterThrows(object):
- def __enter__(self):
- raise RuntimeError("Enter threw")
- def __exit__(self, *args):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class EnterThrows(object):
+ def __enter__(self):
+ raise RuntimeError("Enter threw")
+ def __exit__(self, *args):
+ pass

def shouldThrow():
ct = EnterThrows()
@@ -180,11 +204,12 @@ class FailureTestCase(unittest.TestCase):
self.assertEqual(self.foo, None)

def testExitThrows(self):
- class ExitThrows(object):
- def __enter__(self):
- return
- def __exit__(self, *args):
- raise RuntimeError(42)
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class ExitThrows(object):
+ def __enter__(self):
+ return
+ def __exit__(self, *args):
+ raise RuntimeError(42)
def shouldThrow():
with ExitThrows():
pass
@@ -194,6 +219,7 @@ class ContextmanagerAssertionMixin(object):

def setUp(self):
self.TEST_EXCEPTION = RuntimeError("test exception")
+ super().setUp()

def assertInWithManagerInvariants(self, mock_manager):
self.assertTrue(mock_manager.enter_called)
@@ -237,7 +258,7 @@ class ContextmanagerAssertionMixin(object):
@@ -237,7 +263,7 @@ class ContextmanagerAssertionMixin(object):
self.assertTrue(mock_generator.stopped)


Expand All @@ -52,7 +127,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644
def testInlineGeneratorSyntax(self):
with mock_contextmanager_generator():
pass
@@ -289,7 +310,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
@@ -289,7 +315,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
self.assertAfterWithGeneratorInvariantsNoError(foo)


Expand All @@ -61,7 +136,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644
ContextmanagerAssertionMixin):
def testSingleArgInlineGeneratorSyntax(self):
with Nested(mock_contextmanager_generator()):
@@ -355,7 +376,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
@@ -355,7 +381,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
self.assertAfterWithManagerInvariantsNoError(mock_nested)


Expand All @@ -70,7 +145,71 @@ index 8e9ed8500c7..e1ebaa68b83 100644
def testSingleResource(self):
cm = mock_contextmanager_generator()
def shouldThrow():
@@ -550,7 +571,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
@@ -466,11 +492,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):

def testRaisedStopIteration2(self):
# From bug 1462485
- class cm(object):
- def __enter__(self):
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class cm(object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass

def shouldThrow():
with cm():
@@ -507,11 +534,12 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):

def testRaisedGeneratorExit2(self):
# From bug 1462485
- class cm (object):
- def __enter__(self):
- pass
- def __exit__(self, type, value, traceback):
- pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class cm (object):
+ def __enter__(self):
+ pass
+ def __exit__(self, type, value, traceback):
+ pass

def shouldThrow():
with cm():
@@ -523,16 +551,17 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
# issue4589: __exit__ return code may raise an exception
# when looking at its truth value.

- class cm(object):
- def __init__(self, bool_conversion):
- class Bool:
- def __bool__(self):
- return bool_conversion()
- self.exit_result = Bool()
- def __enter__(self):
- return 3
- def __exit__(self, a, b, c):
- return self.exit_result
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class cm(object):
+ def __init__(self, bool_conversion):
+ class Bool:
+ def __bool__(self):
+ return bool_conversion()
+ self.exit_result = Bool()
+ def __enter__(self):
+ return 3
+ def __exit__(self, a, b, c):
+ return self.exit_result

def trueAsBool():
with cm(lambda: True):
@@ -550,7 +579,7 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase):
self.assertRaises(ZeroDivisionError, failAsBool)


Expand All @@ -79,7 +218,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644

def testWithBreak(self):
counter = 0
@@ -607,7 +628,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
@@ -607,7 +636,7 @@ class NonLocalFlowControlTestCase(unittest.TestCase):
self.fail("Didn't raise RuntimeError")


Expand All @@ -88,16 +227,71 @@ index 8e9ed8500c7..e1ebaa68b83 100644

def testSingleComplexTarget(self):
targets = {1: [0, 1, 2]}
@@ -651,7 +672,7 @@ class AssignmentTargetTestCase(unittest.TestCase):
@@ -621,15 +650,17 @@ class AssignmentTargetTestCase(unittest.TestCase):
keys = list(targets.keys())
keys.sort()
self.assertEqual(keys, [1, 2])
- class C: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
self.assertEqual(hasattr(blah, "foo"), True)

def testMultipleComplexTargets(self):
- class C:
- def __enter__(self): return 1, 2, 3
- def __exit__(self, t, v, tb): pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class C:
+ def __enter__(self): return 1, 2, 3
+ def __exit__(self, t, v, tb): pass
targets = {1: [0, 1, 2]}
with C() as (targets[1][0], targets[1][1], targets[1][2]):
self.assertEqual(targets, {1: [1, 2, 3]})
@@ -637,7 +668,8 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(targets, {1: [3, 2, 1]})
with C() as (targets[1], targets[2], targets[3]):
self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
- class B: pass
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class B: pass
blah = B()
with C() as (blah.one, blah.two, blah.three):
self.assertEqual(blah.one, 1)
@@ -651,12 +683,13 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(c, 4)


-class ExitSwallowsExceptionTestCase(unittest.TestCase):
+class ExitSwallowsExceptionTestCase(__TestCase):

def testExitTrueSwallowsException(self):
class AfricanSwallow:
@@ -676,7 +697,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
- class AfricanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return True
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class AfricanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return True
try:
with AfricanSwallow():
1/0
@@ -664,9 +697,10 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been swallowed")

def testExitFalseDoesntSwallowException(self):
- class EuropeanSwallow:
- def __enter__(self): pass
- def __exit__(self, t, v, tb): return False
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class EuropeanSwallow:
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): return False
try:
with EuropeanSwallow():
1/0
@@ -676,7 +710,7 @@ class ExitSwallowsExceptionTestCase(unittest.TestCase):
self.fail("ZeroDivisionError should have been raised")


Expand All @@ -106,7 +300,7 @@ index 8e9ed8500c7..e1ebaa68b83 100644

class Dummy(object):
def __init__(self, value=None, gobble=False):
@@ -796,4 +817,4 @@ class NestedWith(unittest.TestCase):
@@ -796,4 +830,4 @@ class NestedWith(unittest.TestCase):


if __name__ == '__main__':
Expand Down
Loading
Loading