Skip to content

Commit 4346314

Browse files
Wrap class definitions in set_fullgraph(False) in test_iter
ghstack-source-id: 59ad646 Pull-Request: #160278
1 parent dad3385 commit 4346314

8 files changed

+464
-120
lines changed

test/dynamo/cpython/3_13/test_iter.diff

Lines changed: 337 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
2-
index 1b9f3cf7624..bad1ba94300 100644
2+
index 1b9f3cf7624..6560c7423a6 100644
33
--- a/test/dynamo/cpython/3_13/test_iter.py
44
+++ b/test/dynamo/cpython/3_13/test_iter.py
55
@@ -1,3 +1,60 @@
@@ -63,7 +63,7 @@ index 1b9f3cf7624..bad1ba94300 100644
6363
# Test iterators.
6464

6565
import sys
66-
@@ -104,12 +158,10 @@ class EmptyIterClass:
66+
@@ -104,12 +161,10 @@ class EmptyIterClass:
6767

6868
# Main test suite
6969

@@ -77,7 +77,7 @@ index 1b9f3cf7624..bad1ba94300 100644
7777
res = []
7878
while 1:
7979
try:
80-
@@ -121,8 +173,6 @@ class TestCase(unittest.TestCase):
80+
@@ -121,8 +176,6 @@ class TestCase(unittest.TestCase):
8181

8282
# Helper to check that a for loop generates a given sequence
8383
def check_for_loop(self, expr, seq, pickle=True):
@@ -86,15 +86,347 @@ index 1b9f3cf7624..bad1ba94300 100644
8686
res = []
8787
for val in expr:
8888
res.append(val)
89-
@@ -635,6 +685,7 @@ class TestCase(unittest.TestCase):
89+
@@ -261,19 +314,20 @@ class TestCase(unittest.TestCase):
90+
def run(builtin_name, item, sentinel=None):
91+
it = iter(item) if sentinel is None else iter(item, sentinel)
92+
93+
- class CustomStr:
94+
- def __init__(self, name, iterator):
95+
- self.name = name
96+
- self.iterator = iterator
97+
- def __hash__(self):
98+
- return hash(self.name)
99+
- def __eq__(self, other):
100+
- # Here we exhaust our iterator, possibly changing
101+
- # its `it_seq` pointer to NULL
102+
- # The `__reduce__` call should correctly get
103+
- # the pointers after this call
104+
- list(self.iterator)
105+
- return other == self.name
106+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
107+
+ class CustomStr:
108+
+ def __init__(self, name, iterator):
109+
+ self.name = name
110+
+ self.iterator = iterator
111+
+ def __hash__(self):
112+
+ return hash(self.name)
113+
+ def __eq__(self, other):
114+
+ # Here we exhaust our iterator, possibly changing
115+
+ # its `it_seq` pointer to NULL
116+
+ # The `__reduce__` call should correctly get
117+
+ # the pointers after this call
118+
+ list(self.iterator)
119+
+ return other == self.name
120+
121+
# del is required here
122+
# to not prematurely call __eq__ from
123+
@@ -323,9 +377,10 @@ class TestCase(unittest.TestCase):
124+
125+
# Test a new_style class with __iter__ but no next() method
126+
def test_new_style_iter_class(self):
127+
- class IterClass(object):
128+
- def __iter__(self):
129+
- return self
130+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
131+
+ class IterClass(object):
132+
+ def __iter__(self):
133+
+ return self
134+
self.assertRaises(TypeError, iter, IterClass())
135+
136+
# Test two-argument iter() with callable instance
137+
@@ -394,11 +449,12 @@ class TestCase(unittest.TestCase):
138+
139+
# Test exception propagation through sequence iterator
140+
def test_exception_sequence(self):
141+
- class MySequenceClass(SequenceClass):
142+
- def __getitem__(self, i):
143+
- if i == 10:
144+
- raise RuntimeError
145+
- return SequenceClass.__getitem__(self, i)
146+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
147+
+ class MySequenceClass(SequenceClass):
148+
+ def __getitem__(self, i):
149+
+ if i == 10:
150+
+ raise RuntimeError
151+
+ return SequenceClass.__getitem__(self, i)
152+
res = []
153+
try:
154+
for x in MySequenceClass(20):
155+
@@ -410,11 +466,12 @@ class TestCase(unittest.TestCase):
156+
157+
# Test for StopIteration from __getitem__
158+
def test_stop_sequence(self):
159+
- class MySequenceClass(SequenceClass):
160+
- def __getitem__(self, i):
161+
- if i == 10:
162+
- raise StopIteration
163+
- return SequenceClass.__getitem__(self, i)
164+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
165+
+ class MySequenceClass(SequenceClass):
166+
+ def __getitem__(self, i):
167+
+ if i == 10:
168+
+ raise StopIteration
169+
+ return SequenceClass.__getitem__(self, i)
170+
self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
171+
172+
# Test a big range
173+
@@ -541,32 +598,34 @@ class TestCase(unittest.TestCase):
174+
self.assertRaises(TypeError, filter, None, list)
175+
self.assertRaises(TypeError, filter, None, 42)
176+
177+
- class Boolean:
178+
- def __init__(self, truth):
179+
- self.truth = truth
180+
- def __bool__(self):
181+
- return self.truth
182+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
183+
+ class Boolean:
184+
+ def __init__(self, truth):
185+
+ self.truth = truth
186+
+ def __bool__(self):
187+
+ return self.truth
188+
bTrue = Boolean(True)
189+
bFalse = Boolean(False)
190+
191+
- class Seq:
192+
- def __init__(self, *args):
193+
- self.vals = args
194+
- def __iter__(self):
195+
- class SeqIter:
196+
- def __init__(self, vals):
197+
- self.vals = vals
198+
- self.i = 0
199+
- def __iter__(self):
200+
- return self
201+
- def __next__(self):
202+
- i = self.i
203+
- self.i = i + 1
204+
- if i < len(self.vals):
205+
- return self.vals[i]
206+
- else:
207+
- raise StopIteration
208+
- return SeqIter(self.vals)
209+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
210+
+ class Seq:
211+
+ def __init__(self, *args):
212+
+ self.vals = args
213+
+ def __iter__(self):
214+
+ class SeqIter:
215+
+ def __init__(self, vals):
216+
+ self.vals = vals
217+
+ self.i = 0
218+
+ def __iter__(self):
219+
+ return self
220+
+ def __next__(self):
221+
+ i = self.i
222+
+ self.i = i + 1
223+
+ if i < len(self.vals):
224+
+ return self.vals[i]
225+
+ else:
226+
+ raise StopIteration
227+
+ return SeqIter(self.vals)
228+
229+
seq = Seq(*([bTrue, bFalse] * 25))
230+
self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
231+
@@ -635,6 +694,7 @@ class TestCase(unittest.TestCase):
90232
pass
91233

92234
# Test zip()'s use of iterators.
93235
+ @skipIfTorchDynamo("infinite loop")
94236
def test_builtin_zip(self):
95237
self.assertEqual(list(zip()), [])
96238
self.assertEqual(list(zip(*[])), [])
97-
@@ -1187,4 +1238,4 @@ class TestCase(unittest.TestCase):
239+
@@ -653,17 +713,18 @@ class TestCase(unittest.TestCase):
240+
self.assertEqual(list(d.items()), list(zip(d, d.values())))
241+
242+
# Generate all ints starting at constructor arg.
243+
- class IntsFrom:
244+
- def __init__(self, start):
245+
- self.i = start
246+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
247+
+ class IntsFrom:
248+
+ def __init__(self, start):
249+
+ self.i = start
250+
251+
- def __iter__(self):
252+
- return self
253+
+ def __iter__(self):
254+
+ return self
255+
256+
- def __next__(self):
257+
- i = self.i
258+
- self.i = i+1
259+
- return i
260+
+ def __next__(self):
261+
+ i = self.i
262+
+ self.i = i+1
263+
+ return i
264+
265+
f = open(TESTFN, "w", encoding="utf-8")
266+
try:
267+
@@ -686,19 +747,20 @@ class TestCase(unittest.TestCase):
268+
self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
269+
270+
# Classes that lie about their lengths.
271+
- class NoGuessLen5:
272+
- def __getitem__(self, i):
273+
- if i >= 5:
274+
- raise IndexError
275+
- return i
276+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
277+
+ class NoGuessLen5:
278+
+ def __getitem__(self, i):
279+
+ if i >= 5:
280+
+ raise IndexError
281+
+ return i
282+
283+
- class Guess3Len5(NoGuessLen5):
284+
- def __len__(self):
285+
- return 3
286+
+ class Guess3Len5(NoGuessLen5):
287+
+ def __len__(self):
288+
+ return 3
289+
290+
- class Guess30Len5(NoGuessLen5):
291+
- def __len__(self):
292+
- return 30
293+
+ class Guess30Len5(NoGuessLen5):
294+
+ def __len__(self):
295+
+ return 30
296+
297+
def lzip(*args):
298+
return list(zip(*args))
299+
@@ -718,20 +780,21 @@ class TestCase(unittest.TestCase):
300+
301+
# This class inserts a Unicode object into its argument's natural
302+
# iteration, in the 3rd position.
303+
- class OhPhooey:
304+
- def __init__(self, seq):
305+
- self.it = iter(seq)
306+
- self.i = 0
307+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
308+
+ class OhPhooey:
309+
+ def __init__(self, seq):
310+
+ self.it = iter(seq)
311+
+ self.i = 0
312+
313+
- def __iter__(self):
314+
- return self
315+
+ def __iter__(self):
316+
+ return self
317+
318+
- def __next__(self):
319+
- i = self.i
320+
- self.i = i+1
321+
- if i == 2:
322+
- return "fooled you!"
323+
- return next(self.it)
324+
+ def __next__(self):
325+
+ i = self.i
326+
+ self.i = i+1
327+
+ if i == 2:
328+
+ return "fooled you!"
329+
+ return next(self.it)
330+
331+
f = open(TESTFN, "w", encoding="utf-8")
332+
try:
333+
@@ -895,29 +958,30 @@ class TestCase(unittest.TestCase):
334+
f.writelines({})
335+
336+
# Try a big chunk too.
337+
- class Iterator:
338+
- def __init__(self, start, finish):
339+
- self.start = start
340+
- self.finish = finish
341+
- self.i = self.start
342+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
343+
+ class Iterator:
344+
+ def __init__(self, start, finish):
345+
+ self.start = start
346+
+ self.finish = finish
347+
+ self.i = self.start
348+
349+
- def __next__(self):
350+
- if self.i >= self.finish:
351+
- raise StopIteration
352+
- result = str(self.i) + '\n'
353+
- self.i += 1
354+
- return result
355+
+ def __next__(self):
356+
+ if self.i >= self.finish:
357+
+ raise StopIteration
358+
+ result = str(self.i) + '\n'
359+
+ self.i += 1
360+
+ return result
361+
362+
- def __iter__(self):
363+
- return self
364+
+ def __iter__(self):
365+
+ return self
366+
367+
- class Whatever:
368+
- def __init__(self, start, finish):
369+
- self.start = start
370+
- self.finish = finish
371+
+ class Whatever:
372+
+ def __init__(self, start, finish):
373+
+ self.start = start
374+
+ self.finish = finish
375+
376+
- def __iter__(self):
377+
- return Iterator(self.start, self.finish)
378+
+ def __iter__(self):
379+
+ return Iterator(self.start, self.finish)
380+
381+
f.writelines(Whatever(6, 6+2000))
382+
f.close()
383+
@@ -990,15 +1054,16 @@ class TestCase(unittest.TestCase):
384+
385+
@cpython_only
386+
def test_ref_counting_behavior(self):
387+
- class C(object):
388+
- count = 0
389+
- def __new__(cls):
390+
- cls.count += 1
391+
- return object.__new__(cls)
392+
- def __del__(self):
393+
- cls = self.__class__
394+
- assert cls.count > 0
395+
- cls.count -= 1
396+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
397+
+ class C(object):
398+
+ count = 0
399+
+ def __new__(cls):
400+
+ cls.count += 1
401+
+ return object.__new__(cls)
402+
+ def __del__(self):
403+
+ cls = self.__class__
404+
+ assert cls.count > 0
405+
+ cls.count -= 1
406+
x = C()
407+
self.assertEqual(C.count, 1)
408+
del x
409+
@@ -1089,12 +1154,13 @@ class TestCase(unittest.TestCase):
410+
411+
def test_3720(self):
412+
# Avoid a crash, when an iterator deletes its next() method.
413+
- class BadIterator(object):
414+
- def __iter__(self):
415+
- return self
416+
- def __next__(self):
417+
- del BadIterator.__next__
418+
- return 1
419+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
420+
+ class BadIterator(object):
421+
+ def __iter__(self):
422+
+ return self
423+
+ def __next__(self):
424+
+ del BadIterator.__next__
425+
+ return 1
426+
427+
try:
428+
for i in BadIterator() :
429+
@@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase):
98430

99431

100432
if __name__ == "__main__":

0 commit comments

Comments
 (0)