diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py
index 25c630e5a59f..06a7e8051a4d 100644
--- a/test/dynamo/test_pgo.py
+++ b/test/dynamo/test_pgo.py
@@ -55,6 +55,49 @@ def f(x):
         f(torch.randn(2, 6))
         self.assertEqual(cnts.frame_count, 1)
 
+    def test_optimizer_whitelist(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = torch.nn.Linear(4, 4)
+                self.attr = torch.randn(4)
+
+            def forward(self, x):
+                return self.lin(x) + self.attr
+
+        f = Foo()
+        opt = torch.optim.Adam(f.parameters(), lr=1e-3)
+
+        class Trainer:
+            def __init__(self, mod, opt):
+                self.mod = mod
+                self.opt = opt
+
+            @torch.compile(fullgraph=False)
+            def loop(self, x):
+                self.opt.zero_grad()
+                out = self.mod(x)
+                loss = out.sum()
+                loss.backward()
+                self.opt.step()
+                return out
+
+        trainer = Trainer(f, opt)
+        trainer.loop(torch.randn(2, 4))
+        trainer.loop(torch.randn(4, 4))
+        f.lin = torch.nn.Linear(8, 8)
+        f.attr = torch.randn(8)
+        trainer.opt = torch.optim.Adam(f.parameters(), lr=1e-3)
+        trainer.loop(torch.randn(8, 8))
+
+        # check optimizer tensors not in whitelist
+        state = torch._dynamo.pgo.render_code_state(
+            torch._dynamo.pgo.get_code_state()
+        )
+        whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
+        for name in whitelist.split(","):
+            self.assertFalse("param_groups" in name)
+
     def test_whitelist_suggestion(self):
         cnts = CompileCounter()
 
diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py
index af1ac18a43ce..1ebcf11ceccb 100644
--- a/torch/_dynamo/pgo.py
+++ b/torch/_dynamo/pgo.py
@@ -23,7 +23,7 @@
 import zlib
 from collections import defaultdict
 from typing import Optional, TYPE_CHECKING, TypeVar, Union
-from typing_extensions import override, Self
+from typing_extensions import is_protocol, override, Self
 
 import torch._dynamo.config
 import torch._utils_internal
@@ -234,6 +234,7 @@ class FrameStateSizeEntry:
     stride: Union[
         AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
     ] = dataclasses.field(default=auto_unset)
+    is_optimizer: bool = False
 
     def render(self) -> str:
         # Special cases
@@ -311,25 +312,27 @@ def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
         return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)
 
     @classmethod
-    def make_scalar(cls, x: int) -> FrameStateSizeEntry:
-        return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic)
+    def make_scalar(cls, x: int, is_optimizer: bool = False) -> FrameStateSizeEntry:
+        return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic, is_optimizer=is_optimizer)
 
     @classmethod
     def make_tensor(
-        cls, size: tuple[int, ...], stride: tuple[int, ...]
+        cls, size: tuple[int, ...], stride: tuple[int, ...], is_optimizer: bool = False
     ) -> FrameStateSizeEntry:
         return FrameStateSizeEntry(
             scalar=auto_dynamic,
             size=cls._munge_symint(size),
             stride=cls._munge_symint(stride),
+            is_optimizer=is_optimizer,
         )
 
     @classmethod
-    def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
+    def make_size(cls, size: tuple[int, ...], is_optimizer: bool = False) -> FrameStateSizeEntry:
         return FrameStateSizeEntry(
             scalar=auto_unset,
             size=cls._munge_symint(size),
             stride=auto_unset,
+            is_optimizer=is_optimizer,
         )
 
     @staticmethod
@@ -362,6 +365,7 @@ def __ior__(self, other: Self) -> Self:
         self.scalar = self._merge_atom(self.scalar, other.scalar)
         self.size = self._merge_atom_tup(self.size, other.size)
         self.stride = self._merge_atom_tup(self.stride, other.stride)
+        self.is_optimizer |= other.is_optimizer
         return self
 
 
@@ -597,6 +601,8 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
     for k, v in cs.items():
         cs_terms: list[str] = []
         for src, fs in v.automatic_dynamic.items():
+            if fs.is_optimizer:
+                continue
             cs_terms.append(f" {src}: {fs.render()}")
             if isinstance(fs.size, tuple) and auto_dynamic in fs.size:  # type: ignore[operator]
                 dynamic_sources.add(src)
@@ -604,8 +610,8 @@
     code_state_str = "\n".join(terms)
     if dynamic_sources:
         code_state_str += (
-            "\n\nPGO detected changes a recompilation due to tensor sizes. "
-            "To potentially avoid thisTo reduce shape recompilations by compiling dynamically to start, "
+            "\n\nPGO detected a recompilation due to tensor sizes. "
+            "To reduce shape recompilations by compiling dynamically to start, "
             f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"'
         )
     with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index afa1bc083076..50ad40039fd0 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -2993,7 +2993,7 @@ def is_dynamic_source(source_name: str) -> bool:
 
 
 def record_automatic_dynamic(
-    tx: "InstructionTranslator", name: str, e: torch.Tensor
+    tx: "InstructionTranslator", name: str, source: Source, e: torch.Tensor
 ) -> FrameStateSizeEntry:
     # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset
     ex_size = e.size()
@@ -3013,7 +3013,7 @@ def record_automatic_dynamic(
         stride = []
 
     return process_automatic_dynamic(
-        tx, name, FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride))
+        tx, name, FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride), is_optimizer=is_from_optimizer_source(source))
     )
 
 
@@ -3110,7 +3110,7 @@ def _automatic_dynamic(
     )
 
     if static_shapes and not is_dynamic_source(name):
-        record_automatic_dynamic(tx, name, e)
+        record_automatic_dynamic(tx, name, source, e)
         return StatefulSymbolicContext(
             dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
             dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
@@ -3140,7 +3140,7 @@ def _automatic_dynamic(
     )
 
     # Prep for automatic dynamic
-    frame_state_entry = record_automatic_dynamic(tx, name, e)
+    frame_state_entry = record_automatic_dynamic(tx, name, source, e)
 
     # TODO: index export_constraints ahead of time so we don't have to
     # do a linear scan every time here
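Note on `is_from_optimizer_source`: the builder.py hunks above call this helper, but its definition is not included in these hunks. As a rough, hypothetical sketch only (not necessarily the PR's actual implementation), it could walk the chained guard source and report whether any link in the chain is an `OptimizerSource`:

```python
# Hypothetical sketch -- the real helper is defined elsewhere in the PR.
# Assumes OptimizerSource (torch/_dynamo/source.py) wraps sources that
# originate from an optimizer, e.g. L['self'].opt.param_groups[0]['params'][0].
from typing import Optional

from torch._dynamo.source import OptimizerSource
from torch._guards import Source


def is_from_optimizer_source(source: Optional[Source]) -> bool:
    # Unwrap the chained source (attribute/index accesses) and flag the tensor
    # as optimizer state if any link in the chain is an OptimizerSource.
    while source is not None:
        if isinstance(source, OptimizerSource):
            return True
        source = getattr(source, "base", None)
    return False
```

Since the new test only asserts that no suggested source name contains "param_groups", a name-based check such as `"param_groups" in source.name()` would be an equally plausible shape for this helper.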