From ab42690bbf6578a4d9cc49d12fb860793147edfb Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 9 Jun 2025 20:45:27 -0300 Subject: [PATCH] Update [ghstack-poisoned] --- test/dynamo/test_dicts.py | 13 +++++++++++++ ...thon313-test_dict-SubclassMappingTests.test_bool | 0 ...thon313-test_dict-SubclassMappingTests.test_read | 0 ...dered_dict-CPythonSubclassMappingTests.test_bool | 0 ...dered_dict-CPythonSubclassMappingTests.test_read | 0 torch/_dynamo/symbolic_convert.py | 4 ++-- torch/_dynamo/variables/builtin.py | 6 +++++- 7 files changed, 20 insertions(+), 3 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool delete mode 100644 test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 07dfe587ec36..076dad3e48c3 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -1487,6 +1487,19 @@ def test_type(self): self.assertIsInstance(d, self.thetype) self.assertIs(type(d), self.thetype) + @make_dynamo_test + def test_bool(self): + p = self.thetype() + q = self.thetype({"a": 1, "b": 2}) + if p: + self.fail("empty mapping must compare to False") + if not q: + self.fail("full mapping must compare to True") + if bool(p): + self.fail("empty mapping must compare to False") + if not bool(q): + self.fail("full mapping must compare to True") + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index a0b2efb0a0b6..213943111bdd 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -741,7 +741,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): x = None # __bool__ or __len__ is function - if isinstance(x, UserMethodVariable): + if isinstance(x, (GetAttrVariable, UserMethodVariable)): result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] if isinstance(result, ConstantVariable) and isinstance( result.value, (bool, int) @@ -2416,7 +2416,7 @@ def BUILD_MAP_UNPACK(self, inst): items = self.popn(inst.argval) # ensure everything is a dict items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] - result = {} + result: dict[Any, Any] = {} for x in items: assert isinstance(x, ConstDictVariable) result.update(x.items) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 903376e4894e..61d82847cd5f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2305,6 +2305,10 @@ def call_sorted( list_var.call_method(tx, "sort", [], kwargs) return list_var + def call_bool(self, tx: "InstructionTranslator", obj: VariableTracker): + if isinstance(obj, (ConstDictVariable, UserDefinedDictVariable)): + return ConstantVariable.create(len(obj.items) > 0) + # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): @@ -2531,7 +2535,7 @@ def call_not_(self, tx: "InstructionTranslator", a): # Unwrap the underlying ConstDictVariable if isinstance(a, DictViewVariable): a = a.dv_dict - if isinstance(a, (ListVariable, ConstDictVariable)): + if isinstance(a, (ListVariable, ConstDictVariable, UserDefinedDictVariable)): return ConstantVariable.create(len(a.items) == 0) return None