diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 6e28264d54669..31505b9445d40 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4136,6 +4136,7 @@ def func(): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) + @torch._dynamo.config.patch(assume_dunder_attributes_remain_unchanged=False) def test_meth_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 8a66c847b52a1..27401f36e02f6 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,5 +1,7 @@ # Owner(s): ["module: dynamo"] +import abc import functools +import inspect import unittest import weakref @@ -1150,21 +1152,32 @@ def hook(guard_wrapper, f_locals, builder): def test_nn_module_tag_safe(self): class Foo(torch.nn.Module): + c = 2 + def __init__(self): super().__init__() self.a = 4 + def check(self, x): + return True + def forward(self, x): - return x + self.a + inspect.signature(self.check).parameters.items() + return x + self.a + self.c foo = Foo() - class Baz(torch.nn.Module): + class Env(metaclass=abc.ABCMeta): # noqa: B024 + pass + + class Baz(torch.nn.Module, Env): def __init__(self): super().__init__() self.foo = foo def forward(self, x): + if "Foo" in str(type(self).__mro__): + x = torch.sin(x) return self.foo(x) baz = Baz() @@ -1179,7 +1192,6 @@ def fn(x): from utils import install_guard_manager_testing_hook def hook(guard_wrapper, f_locals, builder): - from torch._C._dynamo.guards import GetGenericDictGuardAccessor from torch._dynamo.source import LocalSource baz_source = LocalSource("baz") @@ -1189,27 +1201,6 @@ def hook(guard_wrapper, f_locals, builder): self.assertTrue(baz_mgr.is_tag_safe()) self.assertTrue(baz_mgr.is_tag_safe_root()) - # Check tagness of baz.__dict__ - self.assertTrue(len(baz_mgr.get_accessors()) == 1) - 
dunder_dict_accessor = baz_mgr.get_accessors()[0] - self.assertTrue( - isinstance(dunder_dict_accessor, GetGenericDictGuardAccessor) - ) - - dunder_dict_mgr = baz_mgr.get_child_managers()[0] - self.assertTrue(dunder_dict_mgr.is_tag_safe()) - self.assertFalse(dunder_dict_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"] - modules_mgr = dunder_dict_mgr.get_child_managers()[0] - self.assertTrue(modules_mgr.is_tag_safe()) - self.assertFalse(modules_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"]["foo"] - modules_foo_mgr = modules_mgr.get_child_managers()[0] - self.assertTrue(modules_foo_mgr.is_tag_safe()) - self.assertFalse(modules_foo_mgr.is_tag_safe_root()) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) with install_guard_manager_testing_hook(hook): opt_fn(torch.randn(4, 4)) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 5e0a014e8f784..64800504f4795 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -354,6 +354,12 @@ class DictGetItemGuardAccessor(GuardAccessor): ... class GetGenericDictGuardAccessor(GuardAccessor): ... class TypeDictGuardAccessor(GuardAccessor): ... class TypeMROGuardAccessor(GuardAccessor): ... +class ClosureGuardAccessor(GuardAccessor): ... +class TupleGetItemGuardAccessor(GuardAccessor): ... +class TypeGuardAccessor(GuardAccessor): ... +class CodeGuardAccessor(GuardAccessor): ... +class FuncDefaultsGuardAccessor(GuardAccessor): ... +class FuncKwDefaultsGuardAccessor(GuardAccessor): ... class GetAttrGuardAccessor(GuardAccessor): def get_attr_name(self) -> str: ... 
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0d83b7078eae9..b8b7561dde16b 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -354,6 +354,25 @@ # Skips guards on func.__defaults__ if the element to be guarded is a constant skip_guards_on_constant_func_defaults = True + +# The recursive-dict-tag guard relies on the class/function identity staying +# stable. We therefore assume that the following function dunder attributes +# are **never rebound** to a different object: +# +# • __code__ • __closure__ +# • __defaults__ • __kwdefaults__ +# • __annotations__ • __mro__ +# +# It is fine to mutate the objects they already point to (e.g. tweak an element +# inside __defaults__), but assignments like +# +# foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED +# +# would invalidate the optimization. This type of rebinding is rare, so we +# assume that the rebinding never happens for guard purposes. Set the flag +# below to False only in environments where such rebinding is known to occur. +assume_dunder_attributes_remain_unchanged = True + # Speedup guard execution of nested nn modules by recursively checking for dict # tags to avoid full guard execution. 
use_recursive_dict_tags_for_guards = True diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a32b8d686dac7..445224319b970 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -48,10 +48,16 @@ from torch._C._dynamo.guards import ( check_obj_id, check_type_id, + ClosureGuardAccessor, + CodeGuardAccessor, dict_version, DictGetItemGuardAccessor, DictGuardManager, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, GetGenericDictGuardAccessor, + GuardAccessor, GuardDebugInfo, GuardManager, install_no_tensor_aliasing_guard, @@ -62,6 +68,10 @@ profile_guard_manager, RelationalGuard, RootGuardManager, + TupleGetItemGuardAccessor, + TypeDictGuardAccessor, + TypeGuardAccessor, + TypeMROGuardAccessor, ) from torch._dynamo.source import ( get_global_source_name, @@ -204,6 +214,17 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") +dunder_attrs_assumed_constants = ( + "__defaults__", + "__kwdefaults__", + "__code__", + "__closure__", + "__annotations__", + "__func__", + "__mro__", +) + + class IndentedBufferWithPrefix(IndentedBuffer): def prefix(self) -> str: return "| " * (self._indent * self.tabwidth) @@ -372,6 +393,16 @@ def find_tag_safe_roots(self) -> None: subset that are tag safe roots. """ + def check_tag_safety( + node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...] + ) -> bool: + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + return all( + isinstance(accessor, accepted_accessors) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: # Just recurse through the key and value dict managers and check if # all of them are tag safe nodes. 
@@ -429,12 +460,8 @@ def visit_manager(node: GuardManager) -> list[GuardManager]: if is_subtree_tag_safe: node.mark_tag_safe() elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, GetGenericDictGuardAccessor) - and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + is_subtree_tag_safe = check_tag_safety( + node, (GetGenericDictGuardAccessor, TypeGuardAccessor) ) if is_subtree_tag_safe: node.mark_tag_safe() @@ -443,6 +470,77 @@ def visit_manager(node: GuardManager) -> list[GuardManager]: return [ node, ] + elif ( + node.get_type_of_guarded_value() + in ( + types.FunctionType, + types.MethodType, + staticmethod, + classmethod, + ) + and config.assume_dunder_attributes_remain_unchanged + ): + # Assumption: callers will not reassign the attributes + # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__. + # Mutating the objects those attributes point to is fine; + # rebinding the attribute itself is not. 
+ # Example ─ allowed: foo.__defaults__[0].bar = 99 + # forbidden: foo.__defaults__ = (3, 4) + is_subtree_tag_safe = check_tag_safety( + node, + ( + CodeGuardAccessor, + ClosureGuardAccessor, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, + ), + ) + + for accessor in node.get_accessors(): + if isinstance(accessor, GetAttrGuardAccessor): + is_subtree_tag_safe &= ( + accessor.get_attr_name() in dunder_attrs_assumed_constants + ) + + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), types.CellType): + is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,)) + + is_subtree_tag_safe &= all( + isinstance(accessor, GetAttrGuardAccessor) + and accessor.get_attr_name() == "cell_contents" + for accessor in node.get_accessors() + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif ( + issubclass(node.get_type_of_guarded_value(), tuple) + and node.get_source().endswith(dunder_attrs_assumed_constants) + and config.assume_dunder_attributes_remain_unchanged + ): + # We trust tuples obtained from a function’s __closure__ or + # __defaults__. Any *other* tuple-valued attribute can be + # silently replaced—for example: + # + # foo.bar = (1, 2) # original + # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see + # + # Therefore only tuples from __closure__ / __defaults__ participate in the + # recursive-dict-tag optimization; all others are ignored. 
+ is_subtree_tag_safe = check_tag_safety( + node, (TupleGetItemGuardAccessor,) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), type): + is_subtree_tag_safe = check_tag_safety( + node, (TypeDictGuardAccessor, TypeMROGuardAccessor) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + return tag_safe_roots def visit(node: GuardManager) -> list[GuardManager]: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 050f39f55895c..4bdcecf3b3c2c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1066,6 +1066,18 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simply use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that’s why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. 
if source_fn is None and kwargs.get("source") is not None: self.source_fn = AttrSource(kwargs.get("source"), "__func__") diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 95b1a37b677fc..084a1e2149d04 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -253,6 +253,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke elif name == "__dict__": options = {"source": source} return variables.GetAttrVariable(self, name, **options) + elif name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) # Special handling of collections.OrderedDict.fromkeys() # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with @@ -295,10 +298,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke func = obj.__get__(None, self.value) return VariableTracker.build(tx, func, source) elif source: - # __mro__ is a member in < 3.12, an attribute in >= 3.12 - if inspect.ismemberdescriptor(obj) or ( - sys.version_info >= (3, 12) and name == "__mro__" - ): + if inspect.ismemberdescriptor(obj): return VariableTracker.build(tx, obj.__get__(self.value), source) if ConstantVariable.is_literal(obj):