From 46472d70b140cb5b11cf442a39d899db074e554f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 4 Aug 2025 15:29:51 -0700 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- torch/_dynamo/guards.py | 45 +++++++++++++++++++++++++++++++++++- torch/csrc/dynamo/guards.cpp | 2 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 50220f3e2329..177cf6ba7da3 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -49,9 +49,11 @@ from torch._C._dynamo.guards import ( check_obj_id, check_type_id, + ClosureGuardAccessor, dict_version, DictGetItemGuardAccessor, DictGuardManager, + GetAttrGuardAccessor, GetGenericDictGuardAccessor, install_no_tensor_aliasing_guard, install_object_aliasing_guard, @@ -59,6 +61,8 @@ install_symbolic_shape_guard, profile_guard_manager, RootGuardManager, + TupleGetItemGuardAccessor, + TypeGuardAccessor, ) from torch._dynamo.source import ( get_global_source_name, @@ -412,7 +416,9 @@ def visit_manager(node): accessors = node.get_accessors() child_mgrs = node.get_child_managers() is_subtree_tag_safe = all( - isinstance(accessor, GetGenericDictGuardAccessor) + isinstance( + accessor, (GetGenericDictGuardAccessor, TypeGuardAccessor) + ) and mgr.is_tag_safe() for accessor, mgr in zip(accessors, child_mgrs) ) @@ -423,6 +429,43 @@ def visit_manager(node): return [ node, ] + elif node.get_type_of_guarded_value() in ( + types.FunctionType, + types.MethodType, + ): + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + is_subtree_tag_safe = all( + isinstance(accessor, ClosureGuardAccessor) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), types.CellType): + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + is_subtree_tag_safe = all( + isinstance(accessor, GetAttrGuardAccessor) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + + is_subtree_tag_safe &= all( + accessor.get_attr_name() == "cell_contents" + for accessor in accessors + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), tuple): + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + is_subtree_tag_safe = all( + isinstance(accessor, TupleGetItemGuardAccessor) + and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + return tag_safe_roots def visit(node): diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 50f879c7f1a7..378df5d19dfd 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -1177,7 +1177,7 @@ bool is_immutable_object(py::handle example_value) { PyLong_Check(example_value.ptr()) || PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) || PyUnicode_Check(example_value.ptr()) || - PyCode_Check(example_value.ptr()) || + PyCode_Check(example_value.ptr()) || PyType_Check(example_value.ptr()) || (Py_TYPE(example_value.ptr()) == &PyCFunction_Type) || (is_tensor_immutable && THPVariable_Check(example_value.ptr())); } From 9459d11c100561c414b627492cca21055572e934 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 4 Aug 2025 23:38:25 -0700 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- test/dynamo/test_functions.py | 3 ++ torch/_C/_dynamo/guards.pyi | 6 +++ torch/_dynamo/config.py | 22 ++++++++++ torch/_dynamo/guards.py | 83 ++++++++++++++++++++++------------- torch/csrc/dynamo/guards.cpp | 2 +- 5 files changed, 85 insertions(+), 31 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4afb6acc5d87..3d7ccabe8b99 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4094,6 +4094,9 @@ def func(): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) + @torch._dynamo.config.patch( + assume_function_dunder_attributes_remain_unchanged=False + ) def test_meth_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 9c2c379ae589..c41c48aa1c7c 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -141,6 +141,12 @@ class DictGetItemGuardAccessor(GuardAccessor): ... class GetGenericDictGuardAccessor(GuardAccessor): ... class TypeDictGuardAccessor(GuardAccessor): ... class TypeMROGuardAccessor(GuardAccessor): ... +class ClosureGuardAccessor(GuardAccessor): ... +class TupleGetItemGuardAccessor(GuardAccessor): ... +class TypeGuardAccessor(GuardAccessor): ... +class CodeGuardAccessor(GuardAccessor): ... +class FuncDefaultsGuardAccessor(GuardAccessor): ... +class FuncKwDefaultsGuardAccessor(GuardAccessor): ... class GetAttrGuardAccessor(GuardAccessor): def get_attr_name(self) -> str: ... diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0d83b7078eae..b4d4cc749ee8 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -354,6 +354,28 @@ # Skips guards on func.__defaults__ if the element to be guarded is a constant skip_guards_on_constant_func_defaults = True + +# The recursive-dict-tag guard relies on the class/function identity staying +# stable. We therefore assume that the following function dunder attributes +# are **never rebound** to a different object: +# +# • __code__ • __closure__ +# • __defaults__ • __kwdefaults__ +# +# It is fine to mutate the objects they already point to (e.g. tweak an element +# inside __defaults__), but assignments like +# +# foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED +# +# would invalidate the optimization. This type of rebinding is rare, so we +# assume that the rebinding never happens for guard purposes. Set the flag +# below to False only in environments where such rebinding is known to occur. +assume_function_dunder_attributes_remain_unchanged = True + +# assume function dunder attributes will not be reaassigned to some other object +# - __code__, __closure__, __defaults__, __kwdefaults__ +assume_function_dunder_attributes_remain_unchanged = True + # Speedup guard execution of nested nn modules by recursively checking for dict # tags to avoid full guard execution. use_recursive_dict_tags_for_guards = True diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 177cf6ba7da3..4a31e4653e2c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -50,9 +50,12 @@ check_obj_id, check_type_id, ClosureGuardAccessor, + CodeGuardAccessor, dict_version, DictGetItemGuardAccessor, DictGuardManager, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, GetAttrGuardAccessor, GetGenericDictGuardAccessor, install_no_tensor_aliasing_guard, @@ -356,6 +359,14 @@ def find_tag_safe_roots(self): subset that are tag safe roots. """ + def check_tag_safety(node, accepted_accessors): + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + return all( + isinstance(accessor, accepted_accessors) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + def visit_dict_manager(node): # Just recurse through the key and value dict managers and check if # all of them are tag safe nodes. @@ -413,14 +424,8 @@ def visit_manager(node): if is_subtree_tag_safe: node.mark_tag_safe() elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance( - accessor, (GetGenericDictGuardAccessor, TypeGuardAccessor) - ) - and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + is_subtree_tag_safe = check_tag_safety( + node, (GetGenericDictGuardAccessor, TypeGuardAccessor) ) if is_subtree_tag_safe: node.mark_tag_safe() @@ -429,40 +434,58 @@ def visit_manager(node): return [ node, ] - elif node.get_type_of_guarded_value() in ( - types.FunctionType, - types.MethodType, + elif ( + node.get_type_of_guarded_value() + in ( + types.FunctionType, + types.MethodType, + ) + and config.assume_function_dunder_attributes_remain_unchanged ): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, ClosureGuardAccessor) and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + # Assumption: callers will not reassignthe attributes + # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__. + # Mutating the objects those attributes point to is fine; + # rebinding the attribute itself is not. + # Example ─ allowed: foo.__defaults__[0].bar = 99 + # forbidden: foo.__defaults__ = (3, 4) + is_subtree_tag_safe = check_tag_safety( + node, + ( + CodeGuardAccessor, + ClosureGuardAccessor, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + ), ) if is_subtree_tag_safe: node.mark_tag_safe() elif issubclass(node.get_type_of_guarded_value(), types.CellType): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, GetAttrGuardAccessor) and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) - ) + is_subtree_tag_safe = check_tag_safety(node, GetAttrGuardAccessor) is_subtree_tag_safe &= all( accessor.get_attr_name() == "cell_contents" - for accessor in accessors + for accessor in node.get_accessors() ) if is_subtree_tag_safe: node.mark_tag_safe() - elif issubclass(node.get_type_of_guarded_value(), tuple): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, TupleGetItemGuardAccessor) - and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + elif ( + issubclass(node.get_type_of_guarded_value(), tuple) + and ( + "__closure__" in node.get_source() + or "__defaults__" in node.get_source() ) + and config.assume_function_dunder_attributes_remain_unchanged + ): + # We trust tuples obtained from a function’s __closure__ or + # __defaults__. Any *other* tuple-valued attribute can be + # silently replaced—for example: + # + # foo.bar = (1, 2) # original + # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see + # + # Therefore only tuples from __closure__ / __defaults__ participate in the + # recursive-dict-tag optimization; all others are ignored. + is_subtree_tag_safe = check_tag_safety(node, TupleGetItemGuardAccessor) if is_subtree_tag_safe: node.mark_tag_safe() diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 6061abda1ae1..ae7aa20be29c 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -1176,7 +1176,7 @@ bool is_immutable_object(py::handle example_value) { PyLong_Check(example_value.ptr()) || PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) || PyUnicode_Check(example_value.ptr()) || - PyCode_Check(example_value.ptr()) || PyType_Check(example_value.ptr()) || + PyCode_Check(example_value.ptr()) || (Py_TYPE(example_value.ptr()) == &PyCFunction_Type) || (is_tensor_immutable && THPVariable_Check(example_value.ptr())); } From 05a1333455cc020ec3c607c86a84576b55ab3870 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 5 Aug 2025 12:20:48 -0700 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- torch/_dynamo/guards.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4a31e4653e2c..dc0fbdc131d4 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -65,7 +65,9 @@ profile_guard_manager, RootGuardManager, TupleGetItemGuardAccessor, + TypeDictGuardAccessor, TypeGuardAccessor, + TypeMROGuardAccessor, ) from torch._dynamo.source import ( get_global_source_name, @@ -455,8 +457,19 @@ def visit_manager(node): ClosureGuardAccessor, FuncDefaultsGuardAccessor, FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, ), ) + + for accessor in node.get_accessors(): + if isinstance(accessor, GetAttrGuardAccessor): + is_subtree_tag_safe &= accessor.get_attr_name() in ( + "__defaults__", + "__kwdefaults__", + "__code__", + "__closure__", + ) + if is_subtree_tag_safe: node.mark_tag_safe() elif issubclass(node.get_type_of_guarded_value(), types.CellType): @@ -488,6 +501,12 @@ def visit_manager(node): is_subtree_tag_safe = check_tag_safety(node, TupleGetItemGuardAccessor) if is_subtree_tag_safe: node.mark_tag_safe() + elif node.get_type_of_guarded_value() is type: + is_subtree_tag_safe = check_tag_safety( + node, (TypeDictGuardAccessor, TypeMROGuardAccessor) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() return tag_safe_roots From 4f8cf041efac8ca9d9b24ffe13bb31268afb6a6f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 8 Aug 2025 14:00:37 -0700 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- test/dynamo/test_guard_manager.py | 4 +++- torch/_dynamo/variables/functions.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 8a66c847b52a..ecf9daf5b1a7 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1150,12 +1150,14 @@ def hook(guard_wrapper, f_locals, builder): def test_nn_module_tag_safe(self): class Foo(torch.nn.Module): + c = 2 + def __init__(self): super().__init__() self.a = 4 def forward(self, x): - return x + self.a + return x + self.a + self.c foo = Foo() diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 050f39f55895..4bdcecf3b3c2 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1066,6 +1066,18 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simplly use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that’s why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: self.source_fn = AttrSource(kwargs.get("source"), "__func__") From 1d1a3a33418fea414e341734bfa0cde0a64856e7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 9 Aug 2025 23:42:33 -0700 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- torch/_dynamo/guards.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index bc5c2b0d1892..445224319b97 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -475,6 +475,8 @@ def visit_manager(node: GuardManager) -> list[GuardManager]: in ( types.FunctionType, types.MethodType, + staticmethod, + classmethod, ) and config.assume_dunder_attributes_remain_unchanged ):