Skip to content

[OrderedDict] Implement OrderedDict.move_to_end(key, last=False) #155152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: gh/guilhermeleobas/164/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions test/dynamo/test_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,23 @@ class OrderedDictMethodsTests(DictMethodsTests):
# + move_to_end

@make_dynamo_test
def test_move_to_end(self):
d = self.thetype.fromkeys("abcde")
self.assertEqual("".join(d), "abcde")
d.move_to_end("b")
self.assertEqual("".join(d), "acdeb")

# Test OrderedDict.move_to_end
self.thetype.move_to_end(d, "a")
self.assertEqual("".join(d), "cdeba")

# Test last=False
self.thetype.move_to_end(d, "a", last=False)
self.assertEqual("".join(d), "acdeb")

# Test KeyError
self.assertRaises(KeyError, d.move_to_end, "f")

def test_cmp_eq_order(self):
a = self.thetype.fromkeys("abc")
b = self.thetype.fromkeys("bca")
Expand Down Expand Up @@ -1583,6 +1600,38 @@ def test_binop_ior_return_type(self):
self.assertIs(type(dict(d4).__ior__(d2)), dict)


class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase):
def setUp(self):
torch._dynamo.config.enable_trace_unittest = True
super().setUp()

def tearDown(self):
torch._dynamo.config.enable_trace_unittest = False
return super().tearDown()

def assertEqual(self, x, y):
self.assertTrue(x == y, f"Expected {x} to be equal to {y}")

def assertNotEqual(self, x, y):
self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")

class OrderedDictSubclass(OrderedDict):
def get(self, key, default=None, /):
return default

def move_to_end(self, key, last=True, /):
# change the behavior to something else
self.pop(key)

thetype = OrderedDictSubclass

@make_dynamo_test
def test_move_to_end(self):
p = self.thetype({"a": 1, "b": 2, "c": 3})
p.move_to_end("a")
self.assertEqual(list(p.keys()), list("bc"))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2436,7 +2436,7 @@ def BUILD_MAP_UNPACK(self, inst):
items = self.popn(inst.argval)
# ensure everything is a dict
items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type]
result = {}
result: dict[Any, Any] = {}
for x in items:
assert isinstance(x, ConstDictVariable)
result.update(x.items)
Expand Down
41 changes: 36 additions & 5 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,33 @@ def __init__(
def make_hashable(key):
return key if isinstance(key, Hashable) else Hashable(key)

self.items = {make_hashable(x): v for x, v in items.items()}
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = not is_from_local_source(self.source)
self.original_items = items.copy()
self.user_cls = user_cls

def _get_dict_cls_from_user_cls(self, user_cls):
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)

# avoid executing user code if user_cls is a dict subclass
if user_cls in accepted_dict_types:
dict_cls = user_cls
else:
# <Subclass, ..., dict, object>
dict_cls = next(
base for base in user_cls.__mro__ if base in accepted_dict_types
)
assert dict_cls in accepted_dict_types, dict_cls

# Use a dict instead as the call "defaultdict({make_hashable(x): v ..})"
# would fail as defaultdict expects a callable as first argument
if dict_cls is collections.defaultdict:
dict_cls = dict
return dict_cls

def as_proxy(self):
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

Expand Down Expand Up @@ -610,12 +630,23 @@ def call_method(
return x
elif name == "move_to_end":
self.install_dict_keys_match_guard()
assert not kwargs and len(args) == 1
tx.output.side_effects.mutation(self)
if args[0] not in self:
raise_observed_exception(KeyError, tx)

last = True
if len(args) == 2 and isinstance(args[1], ConstantVariable):
last = args[1].value

if (
kwargs
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value

key = Hashable(args[0])
val = self.items[key]
self.items.pop(key)
self.items[key] = val
Comment on lines -616 to -618
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was wrong with the old implementation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it's because of last=... keyword. It would be hard to implement move_to_end(key, last=False).

self.items.move_to_end(key, last=last)
return ConstantVariable.create(None)
elif name == "__eq__" and istype(
self, ConstDictVariable
Expand Down
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ def call_method(
return BuiltinVariable.call_custom_dict_fromkeys(
tx, self.value, *args, **kwargs
)
elif self.value is collections.OrderedDict and name == "move_to_end":
return args[0].call_method(tx, name, [*args[1:]], kwargs)
elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"):
return variables.ConstantVariable(self.value == args[0].value)
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
Expand Down
Loading