diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 11dd75a47b48..ea78afaa0b1f 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -1551,6 +1551,23 @@ class OrderedDictMethodsTests(DictMethodsTests): # + move_to_end @make_dynamo_test + def test_move_to_end(self): + d = self.thetype.fromkeys("abcde") + self.assertEqual("".join(d), "abcde") + d.move_to_end("b") + self.assertEqual("".join(d), "acdeb") + + # Test OrderedDict.move_to_end + self.thetype.move_to_end(d, "a") + self.assertEqual("".join(d), "cdeba") + + # Test last=False + self.thetype.move_to_end(d, "a", last=False) + self.assertEqual("".join(d), "acdeb") + + # Test KeyError + self.assertRaises(KeyError, d.move_to_end, "f") + def test_cmp_eq_order(self): a = self.thetype.fromkeys("abc") b = self.thetype.fromkeys("bca") @@ -1583,6 +1600,38 @@ def test_binop_ior_return_type(self): self.assertIs(type(dict(d4).__ior__(d2)), dict) +class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase): + def setUp(self): + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = False + return super().tearDown() + + def assertEqual(self, x, y): + self.assertTrue(x == y, f"Expected {x} to be equal to {y}") + + def assertNotEqual(self, x, y): + self.assertFalse(x == y, f"Expected {x} to not be equal to {y}") + + class OrderedDictSubclass(OrderedDict): + def get(self, key, default=None, /): + return default + + def move_to_end(self, key, last=True, /): + # change the behavior to something else + self.pop(key) + + thetype = OrderedDictSubclass + + @make_dynamo_test + def test_move_to_end(self): + p = self.thetype({"a": 1, "b": 2, "c": 3}) + p.move_to_end("a") + self.assertEqual(list(p.keys()), list("bc")) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 546d1bc84f25..bee1d88bc1c7 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2436,7 +2436,7 @@ def BUILD_MAP_UNPACK(self, inst): items = self.popn(inst.argval) # ensure everything is a dict items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] - result = {} + result: dict[Any, Any] = {} for x in items: assert isinstance(x, ConstDictVariable) result.update(x.items) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f3b92080a0ec..459cfe71a19b 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -245,13 +245,33 @@ def __init__( def make_hashable(key): return key if isinstance(key, Hashable) else Hashable(key) - self.items = {make_hashable(x): v for x, v in items.items()} + dict_cls = self._get_dict_cls_from_user_cls(user_cls) + self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed self.should_reconstruct_all = not is_from_local_source(self.source) self.original_items = items.copy() self.user_cls = user_cls + def _get_dict_cls_from_user_cls(self, user_cls): + accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) + + # avoid executing user code if user_cls is a dict subclass + if user_cls in accepted_dict_types: + dict_cls = user_cls + else: + # + dict_cls = next( + base for base in user_cls.__mro__ if base in accepted_dict_types + ) + assert dict_cls in accepted_dict_types, dict_cls + + # Use a dict instead as the call "defaultdict({make_hashable(x): v ..})" + # would fail as defaultdict expects a callable as first argument + if dict_cls is collections.defaultdict: + dict_cls = dict + return dict_cls + def as_proxy(self): return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} @@ -610,12 +630,23 @@ def call_method( return x elif name == "move_to_end": self.install_dict_keys_match_guard() - assert not kwargs and len(args) == 1 tx.output.side_effects.mutation(self) + if args[0] not in self: + raise_observed_exception(KeyError, tx) + + last = True + if len(args) == 2 and isinstance(args[1], ConstantVariable): + last = args[1].value + + if ( + kwargs + and "last" in kwargs + and isinstance(kwargs["last"], ConstantVariable) + ): + last = kwargs.get("last").value + key = Hashable(args[0]) - val = self.items[key] - self.items.pop(key) - self.items[key] = val + self.items.move_to_end(key, last=last) return ConstantVariable.create(None) elif name == "__eq__" and istype( self, ConstDictVariable diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a2a792c2b415..78d22f07d889 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -414,6 +414,8 @@ def call_method( return BuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) + elif self.value is collections.OrderedDict and name == "move_to_end": + return args[0].call_method(tx, name, [*args[1:]], kwargs) elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):