diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index eae7dfcc6816..756d75972b6b 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -1538,6 +1538,19 @@ def test_type(self): self.assertIsInstance(d, self.thetype) self.assertIs(type(d), self.thetype) + @make_dynamo_test + def test_bool(self): + p = self.thetype() + q = self.thetype({"a": 1, "b": 2}) + if p: + self.fail("empty mapping must compare to False") + if not q: + self.fail("full mapping must compare to True") + if bool(p): + self.fail("empty mapping must compare to False") + if not bool(q): + self.fail("full mapping must compare to True") + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index bee1d88bc1c7..a8921a76ccb1 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -760,7 +760,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): x = None # __bool__ or __len__ is function - if isinstance(x, UserMethodVariable): + if isinstance(x, (GetAttrVariable, UserMethodVariable)): result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] if isinstance(result, ConstantVariable) and isinstance( result.value, (bool, int) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index d2b5b8c40077..06c65d32214e 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1493,6 +1493,9 @@ def call_bool(self, tx: "InstructionTranslator", arg): assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) return SymNodeVariable.create(tx, arg.as_proxy() != 0) + if isinstance(arg, (ConstDictVariable, UserDefinedDictVariable)): + return ConstantVariable.create(len(arg.items) > 0) + # TODO handle more cases and merge this with this with `generic_jump`. def call_str(self, tx: "InstructionTranslator", arg): @@ -2791,7 +2794,7 @@ def call_not_(self, tx: "InstructionTranslator", a): # Unwrap the underlying ConstDictVariable if isinstance(a, DictViewVariable): a = a.dv_dict - if isinstance(a, (ListVariable, ConstDictVariable)): + if isinstance(a, (ListVariable, ConstDictVariable, UserDefinedDictVariable)): return ConstantVariable.create(len(a.items) == 0) return None