diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_eq b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_eq deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_ge b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_ge deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_gt b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_gt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_helper_function b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_helper_function deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_le b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_le deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_lt b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_lt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_unary b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_unary deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 07b005e736e88..4fc777ffe7efd 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -24,6 +24,7 @@ # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py # Put the submodules here to avoid circular imports from . import ( + _collections as _collections, builtins as builtins, functools as functools, itertools as itertools, diff --git a/torch/_dynamo/polyfills/_collections.py b/torch/_dynamo/polyfills/_collections.py new file mode 100644 index 0000000000000..9773635ae3058 --- /dev/null +++ b/torch/_dynamo/polyfills/_collections.py @@ -0,0 +1,33 @@ +""" +Python polyfills for builtins +""" + +from collections.abc import Iterable, MutableMapping +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = [] + + +T = TypeVar("T") + + +try: + import _collections # type: ignore[import-not-found] + + @substitute_in_graph(_collections._count_elements) + def _count_elements( + mapping: MutableMapping[T, int], + iterable: Iterable[T], + ) -> None: + "Tally elements from the iterable." + mapping_get = mapping.get + for elem in iterable: + mapping[elem] = mapping_get(elem, 0) + 1 + + __all__.append("_count_elements") + +except ImportError: + pass diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index f306d47ba5f8a..d348a422ff576 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -13,6 +13,7 @@ # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( + "_collections", "builtins", "functools", "itertools", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index f35c325c72d00..ec23875566e75 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2540,6 +2540,13 @@ def call_neg(self, tx: "InstructionTranslator", a): (operator.neg)(a.as_proxy()), sym_num=None, ) + + if ( + isinstance(a, UserDefinedObjectVariable) + and a.call_obj_hasattr(tx, "__neg__").value # type: ignore[attr-defined] + ): + return a.call_method(tx, "__neg__", [], {}) + # None no-ops this handler and lets the driving function proceed return None diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 48665aecfbe59..d803eb59cea11 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -52,6 +52,7 @@ from ..exc import ( handle_observed_exception, ObservedAttributeError, + ObservedKeyError, raise_observed_exception, unimplemented_v2, ) @@ -1874,7 +1875,20 @@ def call_method( ) -> "VariableTracker": method = self._maybe_get_baseclass_method(name) if method in self._dict_methods: - return self._dict_vt.call_method(tx, name, args, kwargs) + # Dict subclasses can override __missing__ to provide fallback + # behavior instead of raising a KeyError. This is used, for example, + # by collections.Counter. + try: + return self._dict_vt.call_method(tx, name, args, kwargs) + except ObservedKeyError: + if ( + name == "__getitem__" + and issubclass(self.python_type(), dict) + and self._maybe_get_baseclass_method("__missing__") + ): + return self.call_method(tx, "__missing__", args, kwargs) + else: + raise return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx):