Skip to content

[Dynamo] Add CPython default dict tests #155263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/guilhermeleobas/167/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions test/dynamo/cpython/3_13/test_defaultdict.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
diff --git a/test/dynamo/cpython/3_13/test_defaultdict.py b/test/dynamo/cpython/3_13/test_defaultdict.py
index bdbe9b81e8f..d55f1dc54c6 100644
--- a/test/dynamo/cpython/3_13/test_defaultdict.py
+++ b/test/dynamo/cpython/3_13/test_defaultdict.py
@@ -1,3 +1,60 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_defaultdict.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ run_tests,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+
"""Unit tests for collections.defaultdict."""

import copy
@@ -9,7 +66,7 @@ from collections import defaultdict
def foobar():
return list

-class TestDefaultDict(unittest.TestCase):
+class TestDefaultDict(__TestCase):

def test_basic(self):
d1 = defaultdict()
@@ -127,11 +184,12 @@ class TestDefaultDict(unittest.TestCase):

def test_recursive_repr(self):
# Issue2045: stack overflow when default_factory is a bound method
- class sub(defaultdict):
- def __init__(self):
- self.default_factory = self._factory
- def _factory(self):
- return []
+ with torch._dynamo.set_fullgraph(fullgraph=False):
+ class sub(defaultdict):
+ def __init__(self):
+ self.default_factory = self._factory
+ def _factory(self):
+ return []
d = sub()
self.assertRegex(repr(d),
r"sub\(<bound method .*sub\._factory "
@@ -187,4 +245,4 @@ class TestDefaultDict(unittest.TestCase):
i |= None

if __name__ == "__main__":
- unittest.main()
+ run_tests()
248 changes: 248 additions & 0 deletions test/dynamo/cpython/3_13/test_defaultdict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

# Test copied from
# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_defaultdict.py

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import (
run_tests,
)

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
"test.mapping_tests",
"test.typinganndata",
"test.test_grammar",
"test.test_math",
"test.test_iter",
"test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
# Check if the import is the problematic one
if fullname in redirect_imports:
try:
# Attempt to import the standalone module
name = fullname.removeprefix("test.")
r = importlib.import_module(name)
# Redirect the module in sys.modules
sys.modules[fullname] = r
# Return a module spec from the found module
return importlib.util.find_spec(name)
except ImportError:
return None
return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======


"""Unit tests for collections.defaultdict."""

import copy
import pickle
import unittest

from collections import defaultdict

def foobar():
return list

class TestDefaultDict(__TestCase):

def test_basic(self):
d1 = defaultdict()
self.assertEqual(d1.default_factory, None)
d1.default_factory = list
d1[12].append(42)
self.assertEqual(d1, {12: [42]})
d1[12].append(24)
self.assertEqual(d1, {12: [42, 24]})
d1[13]
d1[14]
self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
self.assertTrue(d1[12] is not d1[13] is not d1[14])
d2 = defaultdict(list, foo=1, bar=2)
self.assertEqual(d2.default_factory, list)
self.assertEqual(d2, {"foo": 1, "bar": 2})
self.assertEqual(d2["foo"], 1)
self.assertEqual(d2["bar"], 2)
self.assertEqual(d2[42], [])
self.assertIn("foo", d2)
self.assertIn("foo", d2.keys())
self.assertIn("bar", d2)
self.assertIn("bar", d2.keys())
self.assertIn(42, d2)
self.assertIn(42, d2.keys())
self.assertNotIn(12, d2)
self.assertNotIn(12, d2.keys())
d2.default_factory = None
self.assertEqual(d2.default_factory, None)
try:
d2[15]
except KeyError as err:
self.assertEqual(err.args, (15,))
else:
self.fail("d2[15] didn't raise KeyError")
self.assertRaises(TypeError, defaultdict, 1)

def test_missing(self):
d1 = defaultdict()
self.assertRaises(KeyError, d1.__missing__, 42)
d1.default_factory = list
self.assertEqual(d1.__missing__(42), [])

def test_repr(self):
d1 = defaultdict()
self.assertEqual(d1.default_factory, None)
self.assertEqual(repr(d1), "defaultdict(None, {})")
self.assertEqual(eval(repr(d1)), d1)
d1[11] = 41
self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
d2 = defaultdict(int)
self.assertEqual(d2.default_factory, int)
d2[12] = 42
self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
def foo(): return 43
d3 = defaultdict(foo)
self.assertTrue(d3.default_factory is foo)
d3[13]
self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

def test_copy(self):
d1 = defaultdict()
d2 = d1.copy()
self.assertEqual(type(d2), defaultdict)
self.assertEqual(d2.default_factory, None)
self.assertEqual(d2, {})
d1.default_factory = list
d3 = d1.copy()
self.assertEqual(type(d3), defaultdict)
self.assertEqual(d3.default_factory, list)
self.assertEqual(d3, {})
d1[42]
d4 = d1.copy()
self.assertEqual(type(d4), defaultdict)
self.assertEqual(d4.default_factory, list)
self.assertEqual(d4, {42: []})
d4[12]
self.assertEqual(d4, {42: [], 12: []})

# Issue 6637: Copy fails for empty default dict
d = defaultdict()
d['a'] = 42
e = d.copy()
self.assertEqual(e['a'], 42)

def test_shallow_copy(self):
d1 = defaultdict(foobar, {1: 1})
d2 = copy.copy(d1)
self.assertEqual(d2.default_factory, foobar)
self.assertEqual(d2, d1)
d1.default_factory = list
d2 = copy.copy(d1)
self.assertEqual(d2.default_factory, list)
self.assertEqual(d2, d1)

def test_deep_copy(self):
d1 = defaultdict(foobar, {1: [1]})
d2 = copy.deepcopy(d1)
self.assertEqual(d2.default_factory, foobar)
self.assertEqual(d2, d1)
self.assertTrue(d1[1] is not d2[1])
d1.default_factory = list
d2 = copy.deepcopy(d1)
self.assertEqual(d2.default_factory, list)
self.assertEqual(d2, d1)

def test_keyerror_without_factory(self):
d1 = defaultdict()
try:
d1[(1,)]
except KeyError as err:
self.assertEqual(err.args[0], (1,))
else:
self.fail("expected KeyError")

def test_recursive_repr(self):
# Issue2045: stack overflow when default_factory is a bound method
with torch._dynamo.set_fullgraph(fullgraph=False):
class sub(defaultdict):
def __init__(self):
self.default_factory = self._factory
def _factory(self):
return []
d = sub()
self.assertRegex(repr(d),
r"sub\(<bound method .*sub\._factory "
r"of sub\(\.\.\., \{\}\)>, \{\}\)")

def test_callable_arg(self):
self.assertRaises(TypeError, defaultdict, {})

def test_pickling(self):
d = defaultdict(int)
d[1]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(d, proto)
o = pickle.loads(s)
self.assertEqual(d, o)

def test_union(self):
i = defaultdict(int, {1: 1, 2: 2})
s = defaultdict(str, {0: "zero", 1: "one"})

i_s = i | s
self.assertIs(i_s.default_factory, int)
self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
self.assertEqual(list(i_s), [1, 2, 0])

s_i = s | i
self.assertIs(s_i.default_factory, str)
self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
self.assertEqual(list(s_i), [0, 1, 2])

i_ds = i | dict(s)
self.assertIs(i_ds.default_factory, int)
self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
self.assertEqual(list(i_ds), [1, 2, 0])

ds_i = dict(s) | i
self.assertIs(ds_i.default_factory, int)
self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
self.assertEqual(list(ds_i), [0, 1, 2])

with self.assertRaises(TypeError):
i | list(s.items())
with self.assertRaises(TypeError):
list(s.items()) | i

# We inherit a fine |= from dict, so just a few sanity checks here:
i |= list(s.items())
self.assertIs(i.default_factory, int)
self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
self.assertEqual(list(i), [1, 2, 0])

with self.assertRaises(TypeError):
i |= None

if __name__ == "__main__":
run_tests()
Loading