Skip to content

[dynamo] fixes to propagate tag safeness #159807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
1 change: 1 addition & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4136,6 +4136,7 @@ def func():
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 6)

@torch._dynamo.config.patch(assume_dunder_attributes_remain_unchanged=False)
def test_meth_default_tensor_args(self):
"""
Tests that we indeed reference (and mutate) "the one" default tensor arg
Expand Down
39 changes: 15 additions & 24 deletions test/dynamo/test_guard_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Owner(s): ["module: dynamo"]
import abc
import functools
import inspect
import unittest
import weakref

Expand Down Expand Up @@ -1150,21 +1152,32 @@ def hook(guard_wrapper, f_locals, builder):

def test_nn_module_tag_safe(self):
class Foo(torch.nn.Module):
c = 2

def __init__(self):
super().__init__()
self.a = 4

def check(self, x):
return True

def forward(self, x):
return x + self.a
inspect.signature(self.check).parameters.items()
return x + self.a + self.c

foo = Foo()

class Baz(torch.nn.Module):
class Env(metaclass=abc.ABCMeta): # noqa: B024
pass

class Baz(torch.nn.Module, Env):
def __init__(self):
super().__init__()
self.foo = foo

def forward(self, x):
if "Foo" in str(type(self).__mro__):
x = torch.sin(x)
return self.foo(x)

baz = Baz()
Expand All @@ -1179,7 +1192,6 @@ def fn(x):
from utils import install_guard_manager_testing_hook

def hook(guard_wrapper, f_locals, builder):
from torch._C._dynamo.guards import GetGenericDictGuardAccessor
from torch._dynamo.source import LocalSource

baz_source = LocalSource("baz")
Expand All @@ -1189,27 +1201,6 @@ def hook(guard_wrapper, f_locals, builder):
self.assertTrue(baz_mgr.is_tag_safe())
self.assertTrue(baz_mgr.is_tag_safe_root())

# Check tagness of baz.__dict__
self.assertTrue(len(baz_mgr.get_accessors()) == 1)
dunder_dict_accessor = baz_mgr.get_accessors()[0]
self.assertTrue(
isinstance(dunder_dict_accessor, GetGenericDictGuardAccessor)
)

dunder_dict_mgr = baz_mgr.get_child_managers()[0]
self.assertTrue(dunder_dict_mgr.is_tag_safe())
self.assertFalse(dunder_dict_mgr.is_tag_safe_root())

# Check tagness of baz.__dict__["_modules"]
modules_mgr = dunder_dict_mgr.get_child_managers()[0]
self.assertTrue(modules_mgr.is_tag_safe())
self.assertFalse(modules_mgr.is_tag_safe_root())

# Check tagness of baz.__dict__["_modules"]["foo"]
modules_foo_mgr = modules_mgr.get_child_managers()[0]
self.assertTrue(modules_foo_mgr.is_tag_safe())
self.assertFalse(modules_foo_mgr.is_tag_safe_root())

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
with install_guard_manager_testing_hook(hook):
opt_fn(torch.randn(4, 4))
Expand Down
6 changes: 6 additions & 0 deletions torch/_C/_dynamo/guards.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@ class DictGetItemGuardAccessor(GuardAccessor): ...
class GetGenericDictGuardAccessor(GuardAccessor): ...
class TypeDictGuardAccessor(GuardAccessor): ...
class TypeMROGuardAccessor(GuardAccessor): ...
class ClosureGuardAccessor(GuardAccessor): ...
class TupleGetItemGuardAccessor(GuardAccessor): ...
class TypeGuardAccessor(GuardAccessor): ...
class CodeGuardAccessor(GuardAccessor): ...
class FuncDefaultsGuardAccessor(GuardAccessor): ...
class FuncKwDefaultsGuardAccessor(GuardAccessor): ...

class GetAttrGuardAccessor(GuardAccessor):
def get_attr_name(self) -> str: ...
Expand Down
19 changes: 19 additions & 0 deletions torch/_dynamo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,25 @@
# Skips guards on func.__defaults__ if the element to be guarded is a constant
skip_guards_on_constant_func_defaults = True


# The recursive-dict-tag guard relies on the class/function identity staying
# stable. We therefore assume that the following function dunder attributes
# are **never rebound** to a different object:
#
# • __code__ • __closure__
# • __defaults__ • __kwdefaults__
# • __annotations__ • __mro__
#
# It is fine to mutate the objects they already point to (e.g. tweak an element
# inside __defaults__), but assignments like
#
# foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED
#
# would invalidate the optimization. This type of rebinding is rare, so we
# assume that the rebinding never happens for guard purposes. Set the flag
# below to False only in environments where such rebinding is known to occur.
assume_dunder_attributes_remain_unchanged = True

# Speedup guard execution of nested nn modules by recursively checking for dict
# tags to avoid full guard execution.
use_recursive_dict_tags_for_guards = True
Expand Down
110 changes: 104 additions & 6 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@
from torch._C._dynamo.guards import (
check_obj_id,
check_type_id,
ClosureGuardAccessor,
CodeGuardAccessor,
dict_version,
DictGetItemGuardAccessor,
DictGuardManager,
FuncDefaultsGuardAccessor,
FuncKwDefaultsGuardAccessor,
GetAttrGuardAccessor,
GetGenericDictGuardAccessor,
GuardAccessor,
GuardDebugInfo,
GuardManager,
install_no_tensor_aliasing_guard,
Expand All @@ -62,6 +68,10 @@
profile_guard_manager,
RelationalGuard,
RootGuardManager,
TupleGetItemGuardAccessor,
TypeDictGuardAccessor,
TypeGuardAccessor,
TypeMROGuardAccessor,
)
from torch._dynamo.source import (
get_global_source_name,
Expand Down Expand Up @@ -204,6 +214,17 @@
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")


dunder_attrs_assumed_constants = (
"__defaults__",
"__kwdefaults__",
"__code__",
"__closure__",
"__annotations__",
"__func__",
"__mro__",
)


class IndentedBufferWithPrefix(IndentedBuffer):
def prefix(self) -> str:
return "| " * (self._indent * self.tabwidth)
Expand Down Expand Up @@ -372,6 +393,16 @@ def find_tag_safe_roots(self) -> None:
subset that are tag safe roots.
"""

def check_tag_safety(
node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...]
) -> bool:
accessors = node.get_accessors()
child_mgrs = node.get_child_managers()
return all(
isinstance(accessor, accepted_accessors) and mgr.is_tag_safe()
for accessor, mgr in zip(accessors, child_mgrs)
)

def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]:
# Just recurse through the key and value dict managers and check if
# all of them are tag safe nodes.
Expand Down Expand Up @@ -429,12 +460,8 @@ def visit_manager(node: GuardManager) -> list[GuardManager]:
if is_subtree_tag_safe:
node.mark_tag_safe()
elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
accessors = node.get_accessors()
child_mgrs = node.get_child_managers()
is_subtree_tag_safe = all(
isinstance(accessor, GetGenericDictGuardAccessor)
and mgr.is_tag_safe()
for accessor, mgr in zip(accessors, child_mgrs)
is_subtree_tag_safe = check_tag_safety(
node, (GetGenericDictGuardAccessor, TypeGuardAccessor)
)
if is_subtree_tag_safe:
node.mark_tag_safe()
Expand All @@ -443,6 +470,77 @@ def visit_manager(node: GuardManager) -> list[GuardManager]:
return [
node,
]
elif (
node.get_type_of_guarded_value()
in (
types.FunctionType,
types.MethodType,
staticmethod,
classmethod,
)
and config.assume_dunder_attributes_remain_unchanged
):
# Assumption: callers will not reassignthe attributes
# func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__.
# Mutating the objects those attributes point to is fine;
# rebinding the attribute itself is not.
# Example ─ allowed: foo.__defaults__[0].bar = 99
# forbidden: foo.__defaults__ = (3, 4)
is_subtree_tag_safe = check_tag_safety(
node,
(
CodeGuardAccessor,
ClosureGuardAccessor,
FuncDefaultsGuardAccessor,
FuncKwDefaultsGuardAccessor,
GetAttrGuardAccessor,
),
)

for accessor in node.get_accessors():
if isinstance(accessor, GetAttrGuardAccessor):
is_subtree_tag_safe &= (
accessor.get_attr_name() in dunder_attrs_assumed_constants
)

if is_subtree_tag_safe:
node.mark_tag_safe()
elif issubclass(node.get_type_of_guarded_value(), types.CellType):
is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,))

is_subtree_tag_safe &= all(
isinstance(accessor, GetAttrGuardAccessor)
and accessor.get_attr_name() == "cell_contents"
for accessor in node.get_accessors()
)
if is_subtree_tag_safe:
node.mark_tag_safe()
elif (
issubclass(node.get_type_of_guarded_value(), tuple)
and node.get_source().endswith(dunder_attrs_assumed_constants)
and config.assume_dunder_attributes_remain_unchanged
):
# We trust tuples obtained from a function’s __closure__ or
# __defaults__. Any *other* tuple-valued attribute can be
# silently replaced—for example:
#
# foo.bar = (1, 2) # original
# foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see
#
# Therefore only tuples from __closure__ / __defaults__ participate in the
# recursive-dict-tag optimization; all others are ignored.
is_subtree_tag_safe = check_tag_safety(
node, (TupleGetItemGuardAccessor,)
)
if is_subtree_tag_safe:
node.mark_tag_safe()
elif issubclass(node.get_type_of_guarded_value(), type):
is_subtree_tag_safe = check_tag_safety(
node, (TypeDictGuardAccessor, TypeMROGuardAccessor)
)
if is_subtree_tag_safe:
node.mark_tag_safe()

return tag_safe_roots

def visit(node: GuardManager) -> list[GuardManager]:
Expand Down
12 changes: 12 additions & 0 deletions torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,18 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None:
super().__init__(fn=fn, **kwargs)
self.obj = obj
self.source_fn = source_fn
# Note on source and source_fn
# Be careful with `source` when delegating to UserFunctionVariable
# (base-class) methods. In this __init__, `source` is a *bound method*
# object, but the base class expects the underlying *function* object.
# One way is to simplly use `__func__` to unwrap it.
#
# For recursive dict-tag optimizations, it can be faster to fetch the
# function directly from `cls.__dict__`; that’s why we pass on
# `source_fn`. Whenever it is possible to access the function from
# cls.__dict__, we pass that on to `source_fn`. Because bind_args
# operates on the unbound function, most guards should target
# `source_fn` rather than the original `source`.
if source_fn is None and kwargs.get("source") is not None:
self.source_fn = AttrSource(kwargs.get("source"), "__func__")

Expand Down
8 changes: 4 additions & 4 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
elif name == "__dict__":
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
elif name == "__mro__":
attr_source = self.source and TypeMROSource(self.source)
return VariableTracker.build(tx, self.value.__mro__, attr_source)

# Special handling of collections.OrderedDict.fromkeys()
# Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with
Expand Down Expand Up @@ -295,10 +298,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
func = obj.__get__(None, self.value)
return VariableTracker.build(tx, func, source)
elif source:
# __mro__ is a member in < 3.12, an attribute in >= 3.12
if inspect.ismemberdescriptor(obj) or (
sys.version_info >= (3, 12) and name == "__mro__"
):
if inspect.ismemberdescriptor(obj):
return VariableTracker.build(tx, obj.__get__(self.value), source)

if ConstantVariable.is_literal(obj):
Expand Down
Loading