Skip to content

Commit dabbfc1

Browse files
[OrderedDict] Add bool(OrderedDict)
ghstack-source-id: 26d7fd7 Pull Request resolved: #155503
1 parent e13ce5e commit dabbfc1

7 files changed

+19
-2
lines changed

test/dynamo/test_dicts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,19 @@ def test_type(self):
14891489
self.assertIsInstance(d, self.thetype)
14901490
self.assertIs(type(d), self.thetype)
14911491

1492+
@make_dynamo_test
1493+
def test_bool(self):
1494+
p = self.thetype()
1495+
q = self.thetype({"a": 1, "b": 2})
1496+
if p:
1497+
self.fail("empty mapping must compare to False")
1498+
if not q:
1499+
self.fail("full mapping must compare to True")
1500+
if bool(p):
1501+
self.fail("empty mapping must compare to False")
1502+
if not bool(q):
1503+
self.fail("full mapping must compare to True")
1504+
14921505

14931506
class DictSubclassMethodsTests(DictMethodsTests):
14941507
thetype = SimpleDict

test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read

Whitespace-only changes.

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
741741
x = None
742742

743743
# __bool__ or __len__ is function
744-
if isinstance(x, UserMethodVariable):
744+
if isinstance(x, (GetAttrVariable, UserMethodVariable)):
745745
result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment]
746746
if isinstance(result, ConstantVariable) and isinstance(
747747
result.value, (bool, int)

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2308,6 +2308,10 @@ def call_sorted(
23082308
list_var.call_method(tx, "sort", [], kwargs)
23092309
return list_var
23102310

2311+
def call_bool(self, tx: "InstructionTranslator", obj: VariableTracker):
2312+
if isinstance(obj, (ConstDictVariable, UserDefinedDictVariable)):
2313+
return ConstantVariable.create(len(obj.items) > 0)
2314+
23112315
# neg is a constant fold function, so we only get here if constant fold is not valid
23122316
def call_neg(self, tx: "InstructionTranslator", a):
23132317
if isinstance(a, SymNodeVariable):
@@ -2534,7 +2538,7 @@ def call_not_(self, tx: "InstructionTranslator", a):
25342538
# Unwrap the underlying ConstDictVariable
25352539
if isinstance(a, DictViewVariable):
25362540
a = a.dv_dict
2537-
if isinstance(a, (ListVariable, ConstDictVariable)):
2541+
if isinstance(a, (ListVariable, ConstDictVariable, UserDefinedDictVariable)):
25382542
return ConstantVariable.create(len(a.items) == 0)
25392543

25402544
return None

0 commit comments

Comments
 (0)