Skip to content

Commit 22f9f31

Browse files
Update
[ghstack-poisoned]
2 parents c8f6a72 + a027969 commit 22f9f31

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

torch/_dynamo/variables/dicts.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -678,32 +678,24 @@ def call_method(
678678
elif name == "__or__":
679679
assert len(args) == 1
680680
# Dicts can only be unioned with other dicts or subclasses of dicts.
681-
# Sets can be unioned with other sets, frozensets or subclasses of sets.
682-
_raise = not (
683-
(
684-
istype(self, ConstDictVariable)
685-
and istype(
686-
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
687-
)
688-
)
689-
or (
690-
isinstance(self, SetVariable)
691-
and isinstance(
692-
args[0], (SetVariable, variables.UserDefinedSetVariable)
693-
)
694-
)
695-
)
696-
697-
if _raise:
681+
if not istype(
682+
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
683+
):
698684
msg = (
699685
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
700686
f"and '{args[0].python_type().__name__}'"
701687
)
702688
raise_observed_exception(TypeError, tx, args=[msg])
703689

690+
# Rule of thumb:
691+
# - If either user_cls is defaultdict, use dict
692+
# - If either is OrderedDict, use OrderedDict
693+
# - Otherwise, use dict
704694
ts = {self.user_cls, args[0].user_cls}
705695
user_cls = (
706-
collections.OrderedDict if collections.OrderedDict in ts else dict
696+
collections.OrderedDict
697+
if any(issubclass(t, collections.OrderedDict) for t in ts)
698+
else dict
707699
)
708700

709701
self.install_dict_keys_match_guard()

torch/_dynamo/variables/user_defined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,7 @@ def __init__(self, value, dict_vt=None, **kwargs):
18691869
else dict
18701870
)
18711871
self._dict_vt = variables.ConstDictVariable(
1872-
user_cls(), user_cls=user_cls, mutation_type=ValueMutationNew()
1872+
{}, type(value), mutation_type=ValueMutationNew()
18731873
)
18741874
self._dict_methods = dict_methods
18751875

0 commit comments

Comments
 (0)