Skip to content

Commit 12af9c8

Browse files
[OrderedDict] Add bool(OrderedDict)
ghstack-source-id: 886577f Pull Request resolved: #155503
1 parent 83ca1d2 commit 12af9c8

7 files changed

+20
-3
lines changed

test/dynamo/test_dicts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,19 @@ def test_type(self):
14871487
self.assertIsInstance(d, self.thetype)
14881488
self.assertIs(type(d), self.thetype)
14891489

1490+
@make_dynamo_test
1491+
def test_bool(self):
1492+
p = self.thetype()
1493+
q = self.thetype({"a": 1, "b": 2})
1494+
if p:
1495+
self.fail("empty mapping must compare to False")
1496+
if not q:
1497+
self.fail("full mapping must compare to True")
1498+
if bool(p):
1499+
self.fail("empty mapping must compare to False")
1500+
if not bool(q):
1501+
self.fail("full mapping must compare to True")
1502+
14901503

14911504
class DictSubclassMethodsTests(DictMethodsTests):
14921505
thetype = SimpleDict

test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read

Whitespace-only changes.

torch/_dynamo/symbolic_convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
741741
x = None
742742

743743
# __bool__ or __len__ is function
744-
if isinstance(x, UserMethodVariable):
744+
if isinstance(x, (GetAttrVariable, UserMethodVariable)):
745745
result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment]
746746
if isinstance(result, ConstantVariable) and isinstance(
747747
result.value, (bool, int)
@@ -2416,7 +2416,7 @@ def BUILD_MAP_UNPACK(self, inst):
24162416
items = self.popn(inst.argval)
24172417
# ensure everything is a dict
24182418
items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type]
2419-
result = {}
2419+
result: dict[Any, Any] = {}
24202420
for x in items:
24212421
assert isinstance(x, ConstDictVariable)
24222422
result.update(x.items)

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,10 @@ def call_sorted(
23052305
list_var.call_method(tx, "sort", [], kwargs)
23062306
return list_var
23072307

2308+
def call_bool(self, tx: "InstructionTranslator", obj: VariableTracker):
2309+
if isinstance(obj, (ConstDictVariable, UserDefinedDictVariable)):
2310+
return ConstantVariable.create(len(obj.items) > 0)
2311+
23082312
# neg is a constant fold function, so we only get here if constant fold is not valid
23092313
def call_neg(self, tx: "InstructionTranslator", a):
23102314
if isinstance(a, SymNodeVariable):
@@ -2531,7 +2535,7 @@ def call_not_(self, tx: "InstructionTranslator", a):
25312535
# Unwrap the underlying ConstDictVariable
25322536
if isinstance(a, DictViewVariable):
25332537
a = a.dv_dict
2534-
if isinstance(a, (ListVariable, ConstDictVariable)):
2538+
if isinstance(a, (ListVariable, ConstDictVariable, UserDefinedDictVariable)):
25352539
return ConstantVariable.create(len(a.items) == 0)
25362540

25372541
return None

0 commit comments

Comments
 (0)