From 3dd9fe3f79081f3f348ccd92284087c51c9ef40f Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 29 Jul 2025 11:33:32 -0300 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- ...on313-test_collections-TestCounter.test_eq | 0 ...on313-test_collections-TestCounter.test_ge | 0 ...on313-test_collections-TestCounter.test_gt | 0 ...llections-TestCounter.test_helper_function | 0 ...on313-test_collections-TestCounter.test_le | 0 ...on313-test_collections-TestCounter.test_lt | 0 ...13-test_collections-TestCounter.test_unary | 0 torch/_dynamo/polyfills/__init__.py | 1 + torch/_dynamo/polyfills/_collections.py | 27 +++++++++++++++++++ torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/variables/builder.py | 4 ++- torch/_dynamo/variables/builtin.py | 9 ++++++- torch/_dynamo/variables/user_defined.py | 21 ++++++++++++--- 13 files changed, 58 insertions(+), 5 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_eq delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_ge delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_gt delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_helper_function delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_le delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_lt delete mode 100644 test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_unary create mode 100644 torch/_dynamo/polyfills/_collections.py diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_eq b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_eq deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_ge b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_ge deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_gt b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_gt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_helper_function b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_helper_function deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_le b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_le deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_lt b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_lt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_unary b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_unary deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 07b005e736e8..4fc777ffe7ef 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -24,6 +24,7 @@ # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py # Put the submodules here to avoid circular imports from . import ( + _collections as _collections, builtins as builtins, functools as functools, itertools as itertools, diff --git a/torch/_dynamo/polyfills/_collections.py b/torch/_dynamo/polyfills/_collections.py new file mode 100644 index 000000000000..41083a70d376 --- /dev/null +++ b/torch/_dynamo/polyfills/_collections.py @@ -0,0 +1,27 @@ +""" +Python polyfills for builtins +""" + +from ..decorators import substitute_in_graph +from typing import MutableMapping, Iterable, Any + + +__all__ = [] + +try: + import _collections + + @substitute_in_graph(_collections._count_elements) + def _count_elements( + mapping: MutableMapping[Any, int], + iterable: Iterable[Any], + ) -> None: + 'Tally elements from the iterable.' + mapping_get = mapping.get + for elem in iterable: + mapping[elem] = mapping_get(elem, 0) + 1 + + __all__.append("_count_elements") + +except ImportError: + pass \ No newline at end of file diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index f306d47ba5f8..d348a422ff57 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -13,6 +13,7 @@ # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( + "_collections", "builtins", "functools", "itertools", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 08c2d6b14eec..92dd989e14a6 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3583,7 +3583,9 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker: if trace_rules.is_callable_allowed(value): tx.output.has_user_defined_allowed_in_graph = True return trace_rules.lookup_callable(value)(value) - elif callable(value) and UserDefinedClassVariable.is_supported_new_method(value): + elif callable(value) and UserDefinedClassVariable.is_supported_new_method( + value + ): # NamedTuple._make uses an alias of tuple.__new__ obj = trace_rules.lookup_callable(value.__self__)(value.__self__) return GetAttrVariable(obj, "__new__") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 628226b4228a..a8305adbd87d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2532,6 +2532,13 @@ def call_neg(self, tx: "InstructionTranslator", a): (operator.neg)(a.as_proxy()), sym_num=None, ) + + if ( + isinstance(a, UserDefinedObjectVariable) + and a.call_obj_hasattr(tx, "__neg__").value + ): + return a.call_method(tx, "__neg__", (), {}) + # None no-ops this handler and lets the driving function proceed return None @@ -2771,7 +2778,7 @@ def call_ior(self, tx: "InstructionTranslator", a, b): DictKeysVariable, MutableMappingVariable, SetVariable, - UserDefinedSetVariable + UserDefinedSetVariable, ), ): return a.call_method(tx, "__ior__", [b], {}) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 9549f1edcbda..209b08770f90 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -52,6 +52,7 @@ from ..exc import ( handle_observed_exception, ObservedAttributeError, + ObservedKeyError, raise_observed_exception, unimplemented_v2, ) @@ -1790,6 +1791,18 @@ def call_method( ) -> "VariableTracker": method = self._maybe_get_baseclass_method(name) if method in self._dict_methods: + # Dict subclasses can override __missing__ to provide fallback + # behavior instead of raising a KeyError. This is used, for example, + # by collections.Counter. + if ( + name == "__getitem__" + and issubclass(self.python_type(), dict) + and self.call_obj_hasattr(tx, "__missing__").value + ): + try: + return self._dict_vt.call_method(tx, name, args, kwargs) + except ObservedKeyError: + return self.call_method(tx, "__missing__", args, kwargs) return self._dict_vt.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) @@ -2019,9 +2032,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke def unpack_var_sequence(self, tx): # This shouldn't be necessary if iter(...) is implemented correctly # return super().unpack_var_sequence(tx) - return variables.UserFunctionVariable( - polyfills.builtins.iter_ - ).call_function(tx, [self], {}).items + return ( + variables.UserFunctionVariable(polyfills.builtins.iter_) + .call_function(tx, [self], {}) + .items + ) class RandomVariable(UserDefinedObjectVariable): From 5a9c7fc7732342f2d48613dad3ca728d9827c6d3 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 31 Jul 2025 23:25:16 -0300 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torch/_dynamo/polyfills/_collections.py | 10 ++++++---- torch/_dynamo/variables/builtin.py | 14 +++++++++----- torch/_dynamo/variables/user_defined.py | 19 ++++++++++--------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/torch/_dynamo/polyfills/_collections.py b/torch/_dynamo/polyfills/_collections.py index 41083a70d376..d9177b4ba66d 100644 --- a/torch/_dynamo/polyfills/_collections.py +++ b/torch/_dynamo/polyfills/_collections.py @@ -2,21 +2,23 @@ Python polyfills for builtins """ +from collections.abc import Iterable, MutableMapping +from typing import Any + from ..decorators import substitute_in_graph -from typing import MutableMapping, Iterable, Any __all__ = [] try: - import _collections + import _collections # type: ignore[import-not-found] @substitute_in_graph(_collections._count_elements) def _count_elements( mapping: MutableMapping[Any, int], iterable: Iterable[Any], ) -> None: - 'Tally elements from the iterable.' + "Tally elements from the iterable." mapping_get = mapping.get for elem in iterable: mapping[elem] = mapping_get(elem, 0) + 1 @@ -24,4 +26,4 @@ def _count_elements( __all__.append("_count_elements") except ImportError: - pass \ No newline at end of file + pass diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 0b22ff39fe53..16e004c49e85 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1373,9 +1373,13 @@ def call_method( if ( self.fn is tuple and len(args) == 2 - and tx.inline_user_function_return( - VariableTracker.build(tx, polyfills.is_iterable), [args[1]], {} - ).value + and ( + is_iterable := tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.is_iterable), [args[1]], {} + ) + ) + and isinstance(is_iterable, ConstantVariable) + and is_iterable.value and not kwargs ): if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: @@ -2534,9 +2538,9 @@ def call_neg(self, tx: "InstructionTranslator", a): if ( isinstance(a, UserDefinedObjectVariable) - and a.call_obj_hasattr(tx, "__neg__").value + and a.call_obj_hasattr(tx, "__neg__").value # type: ignore[attr-defined] ): - return a.call_method(tx, "__neg__", (), {}) + return a.call_method(tx, "__neg__", [], {}) # None no-ops this handler and lets the driving function proceed return None diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ee8664cba68b..7887ad0c82be 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1794,16 +1794,17 @@ def call_method( # Dict subclasses can override __missing__ to provide fallback # behavior instead of raising a KeyError. This is used, for example, # by collections.Counter. - if ( - name == "__getitem__" - and issubclass(self.python_type(), dict) - and self.call_obj_hasattr(tx, "__missing__").value - ): - try: - return self._dict_vt.call_method(tx, name, args, kwargs) - except ObservedKeyError: + try: + return self._dict_vt.call_method(tx, name, args, kwargs) + except ObservedKeyError: + if ( + name == "__getitem__" + and issubclass(self.python_type(), dict) + and self._maybe_get_baseclass_method("__missing__") + ): return self.call_method(tx, "__missing__", args, kwargs) - return self._dict_vt.call_method(tx, name, args, kwargs) + else: + raise return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx):