Skip to content

Commit 61ac089

Browse files
[Dynamo] Add CPython default dict tests
ghstack-source-id: bb4f367 Pull Request resolved: #155263
1 parent da1f608 commit 61ac089

12 files changed

+346
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
diff --git a/test/dynamo/cpython/3_13/test_defaultdict.py b/test/dynamo/cpython/3_13/test_defaultdict.py
2+
index bdbe9b81e8f..d55f1dc54c6 100644
3+
--- a/test/dynamo/cpython/3_13/test_defaultdict.py
4+
+++ b/test/dynamo/cpython/3_13/test_defaultdict.py
5+
@@ -1,3 +1,60 @@
6+
+# ======= BEGIN Dynamo patch =======
7+
+# Owner(s): ["module: dynamo"]
8+
+
9+
+# ruff: noqa
10+
+# flake8: noqa
11+
+
12+
+# Test copied from
13+
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_defaultdict.py
14+
+
15+
+import sys
16+
+import torch
17+
+import torch._dynamo.test_case
18+
+import unittest
19+
+from torch._dynamo.test_case import CPythonTestCase
20+
+from torch.testing._internal.common_utils import (
21+
+ run_tests,
22+
+)
23+
+
24+
+__TestCase = CPythonTestCase
25+
+
26+
+
27+
+# redirect import statements
28+
+import sys
29+
+import importlib.abc
30+
+
31+
+redirect_imports = (
32+
+ "test.mapping_tests",
33+
+ "test.typinganndata",
34+
+ "test.test_grammar",
35+
+ "test.test_math",
36+
+ "test.test_iter",
37+
+ "test.typinganndata.ann_module",
38+
+)
39+
+
40+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
41+
+ def find_spec(self, fullname, path, target=None):
42+
+ # Check if the import is the problematic one
43+
+ if fullname in redirect_imports:
44+
+ try:
45+
+ # Attempt to import the standalone module
46+
+ name = fullname.removeprefix("test.")
47+
+ r = importlib.import_module(name)
48+
+ # Redirect the module in sys.modules
49+
+ sys.modules[fullname] = r
50+
+ # Return a module spec from the found module
51+
+ return importlib.util.find_spec(name)
52+
+ except ImportError:
53+
+ return None
54+
+ return None
55+
+
56+
+# Add the custom finder to sys.meta_path
57+
+sys.meta_path.insert(0, RedirectImportFinder())
58+
+
59+
+
60+
+# ======= END DYNAMO PATCH =======
61+
+
62+
+
63+
"""Unit tests for collections.defaultdict."""
64+
65+
import copy
66+
@@ -9,7 +66,7 @@ from collections import defaultdict
67+
def foobar():
68+
return list
69+
70+
-class TestDefaultDict(unittest.TestCase):
71+
+class TestDefaultDict(__TestCase):
72+
73+
def test_basic(self):
74+
d1 = defaultdict()
75+
@@ -127,11 +184,12 @@ class TestDefaultDict(unittest.TestCase):
76+
77+
def test_recursive_repr(self):
78+
# Issue2045: stack overflow when default_factory is a bound method
79+
- class sub(defaultdict):
80+
- def __init__(self):
81+
- self.default_factory = self._factory
82+
- def _factory(self):
83+
- return []
84+
+ with torch._dynamo.set_fullgraph(fullgraph=False):
85+
+ class sub(defaultdict):
86+
+ def __init__(self):
87+
+ self.default_factory = self._factory
88+
+ def _factory(self):
89+
+ return []
90+
d = sub()
91+
self.assertRegex(repr(d),
92+
r"sub\(<bound method .*sub\._factory "
93+
@@ -187,4 +245,4 @@ class TestDefaultDict(unittest.TestCase):
94+
i |= None
95+
96+
if __name__ == "__main__":
97+
- unittest.main()
98+
+ run_tests()
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# ======= BEGIN Dynamo patch =======
2+
# Owner(s): ["module: dynamo"]
3+
4+
# ruff: noqa
5+
# flake8: noqa
6+
7+
# Test copied from
8+
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_defaultdict.py
9+
10+
import sys
11+
import torch
12+
import torch._dynamo.test_case
13+
import unittest
14+
from torch._dynamo.test_case import CPythonTestCase
15+
from torch.testing._internal.common_utils import (
16+
run_tests,
17+
)
18+
19+
__TestCase = CPythonTestCase
20+
21+
22+
# redirect import statements
23+
import sys
24+
import importlib.abc
25+
26+
redirect_imports = (
27+
"test.mapping_tests",
28+
"test.typinganndata",
29+
"test.test_grammar",
30+
"test.test_math",
31+
"test.test_iter",
32+
"test.typinganndata.ann_module",
33+
)
34+
35+
class RedirectImportFinder(importlib.abc.MetaPathFinder):
36+
def find_spec(self, fullname, path, target=None):
37+
# Check if the import is the problematic one
38+
if fullname in redirect_imports:
39+
try:
40+
# Attempt to import the standalone module
41+
name = fullname.removeprefix("test.")
42+
r = importlib.import_module(name)
43+
# Redirect the module in sys.modules
44+
sys.modules[fullname] = r
45+
# Return a module spec from the found module
46+
return importlib.util.find_spec(name)
47+
except ImportError:
48+
return None
49+
return None
50+
51+
# Add the custom finder to sys.meta_path
52+
sys.meta_path.insert(0, RedirectImportFinder())
53+
54+
55+
# ======= END DYNAMO PATCH =======
56+
57+
58+
"""Unit tests for collections.defaultdict."""
59+
60+
import copy
61+
import pickle
62+
import unittest
63+
64+
from collections import defaultdict
65+
66+
def foobar():
67+
return list
68+
69+
class TestDefaultDict(__TestCase):
70+
71+
def test_basic(self):
72+
d1 = defaultdict()
73+
self.assertEqual(d1.default_factory, None)
74+
d1.default_factory = list
75+
d1[12].append(42)
76+
self.assertEqual(d1, {12: [42]})
77+
d1[12].append(24)
78+
self.assertEqual(d1, {12: [42, 24]})
79+
d1[13]
80+
d1[14]
81+
self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
82+
self.assertTrue(d1[12] is not d1[13] is not d1[14])
83+
d2 = defaultdict(list, foo=1, bar=2)
84+
self.assertEqual(d2.default_factory, list)
85+
self.assertEqual(d2, {"foo": 1, "bar": 2})
86+
self.assertEqual(d2["foo"], 1)
87+
self.assertEqual(d2["bar"], 2)
88+
self.assertEqual(d2[42], [])
89+
self.assertIn("foo", d2)
90+
self.assertIn("foo", d2.keys())
91+
self.assertIn("bar", d2)
92+
self.assertIn("bar", d2.keys())
93+
self.assertIn(42, d2)
94+
self.assertIn(42, d2.keys())
95+
self.assertNotIn(12, d2)
96+
self.assertNotIn(12, d2.keys())
97+
d2.default_factory = None
98+
self.assertEqual(d2.default_factory, None)
99+
try:
100+
d2[15]
101+
except KeyError as err:
102+
self.assertEqual(err.args, (15,))
103+
else:
104+
self.fail("d2[15] didn't raise KeyError")
105+
self.assertRaises(TypeError, defaultdict, 1)
106+
107+
def test_missing(self):
108+
d1 = defaultdict()
109+
self.assertRaises(KeyError, d1.__missing__, 42)
110+
d1.default_factory = list
111+
self.assertEqual(d1.__missing__(42), [])
112+
113+
def test_repr(self):
114+
d1 = defaultdict()
115+
self.assertEqual(d1.default_factory, None)
116+
self.assertEqual(repr(d1), "defaultdict(None, {})")
117+
self.assertEqual(eval(repr(d1)), d1)
118+
d1[11] = 41
119+
self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
120+
d2 = defaultdict(int)
121+
self.assertEqual(d2.default_factory, int)
122+
d2[12] = 42
123+
self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
124+
def foo(): return 43
125+
d3 = defaultdict(foo)
126+
self.assertTrue(d3.default_factory is foo)
127+
d3[13]
128+
self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
129+
130+
def test_copy(self):
131+
d1 = defaultdict()
132+
d2 = d1.copy()
133+
self.assertEqual(type(d2), defaultdict)
134+
self.assertEqual(d2.default_factory, None)
135+
self.assertEqual(d2, {})
136+
d1.default_factory = list
137+
d3 = d1.copy()
138+
self.assertEqual(type(d3), defaultdict)
139+
self.assertEqual(d3.default_factory, list)
140+
self.assertEqual(d3, {})
141+
d1[42]
142+
d4 = d1.copy()
143+
self.assertEqual(type(d4), defaultdict)
144+
self.assertEqual(d4.default_factory, list)
145+
self.assertEqual(d4, {42: []})
146+
d4[12]
147+
self.assertEqual(d4, {42: [], 12: []})
148+
149+
# Issue 6637: Copy fails for empty default dict
150+
d = defaultdict()
151+
d['a'] = 42
152+
e = d.copy()
153+
self.assertEqual(e['a'], 42)
154+
155+
def test_shallow_copy(self):
156+
d1 = defaultdict(foobar, {1: 1})
157+
d2 = copy.copy(d1)
158+
self.assertEqual(d2.default_factory, foobar)
159+
self.assertEqual(d2, d1)
160+
d1.default_factory = list
161+
d2 = copy.copy(d1)
162+
self.assertEqual(d2.default_factory, list)
163+
self.assertEqual(d2, d1)
164+
165+
def test_deep_copy(self):
166+
d1 = defaultdict(foobar, {1: [1]})
167+
d2 = copy.deepcopy(d1)
168+
self.assertEqual(d2.default_factory, foobar)
169+
self.assertEqual(d2, d1)
170+
self.assertTrue(d1[1] is not d2[1])
171+
d1.default_factory = list
172+
d2 = copy.deepcopy(d1)
173+
self.assertEqual(d2.default_factory, list)
174+
self.assertEqual(d2, d1)
175+
176+
def test_keyerror_without_factory(self):
177+
d1 = defaultdict()
178+
try:
179+
d1[(1,)]
180+
except KeyError as err:
181+
self.assertEqual(err.args[0], (1,))
182+
else:
183+
self.fail("expected KeyError")
184+
185+
def test_recursive_repr(self):
186+
# Issue2045: stack overflow when default_factory is a bound method
187+
with torch._dynamo.set_fullgraph(fullgraph=False):
188+
class sub(defaultdict):
189+
def __init__(self):
190+
self.default_factory = self._factory
191+
def _factory(self):
192+
return []
193+
d = sub()
194+
self.assertRegex(repr(d),
195+
r"sub\(<bound method .*sub\._factory "
196+
r"of sub\(\.\.\., \{\}\)>, \{\}\)")
197+
198+
def test_callable_arg(self):
199+
self.assertRaises(TypeError, defaultdict, {})
200+
201+
def test_pickling(self):
202+
d = defaultdict(int)
203+
d[1]
204+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
205+
s = pickle.dumps(d, proto)
206+
o = pickle.loads(s)
207+
self.assertEqual(d, o)
208+
209+
def test_union(self):
210+
i = defaultdict(int, {1: 1, 2: 2})
211+
s = defaultdict(str, {0: "zero", 1: "one"})
212+
213+
i_s = i | s
214+
self.assertIs(i_s.default_factory, int)
215+
self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
216+
self.assertEqual(list(i_s), [1, 2, 0])
217+
218+
s_i = s | i
219+
self.assertIs(s_i.default_factory, str)
220+
self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
221+
self.assertEqual(list(s_i), [0, 1, 2])
222+
223+
i_ds = i | dict(s)
224+
self.assertIs(i_ds.default_factory, int)
225+
self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
226+
self.assertEqual(list(i_ds), [1, 2, 0])
227+
228+
ds_i = dict(s) | i
229+
self.assertIs(ds_i.default_factory, int)
230+
self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
231+
self.assertEqual(list(ds_i), [0, 1, 2])
232+
233+
with self.assertRaises(TypeError):
234+
i | list(s.items())
235+
with self.assertRaises(TypeError):
236+
list(s.items()) | i
237+
238+
# We inherit a fine |= from dict, so just a few sanity checks here:
239+
i |= list(s.items())
240+
self.assertIs(i.default_factory, int)
241+
self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
242+
self.assertEqual(list(i), [1, 2, 0])
243+
244+
with self.assertRaises(TypeError):
245+
i |= None
246+
247+
if __name__ == "__main__":
248+
run_tests()

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_basic

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_callable_arg

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_copy

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_deep_copy

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_missing

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_pickling

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_repr

Whitespace-only changes.

0 commit comments

Comments
 (0)